Skip to content

Commit ee6cc04

Browse files
[Feature] Add MinkUNet segmentor (#2294)
* add cylindrical voxelization & voxel feature encoder * add cylindrical voxelization & voxel feature encoder * add voxel-wise label & voxelization UT * fix vfe * fix vfe UT * rename voxel encoder & add more test case * fix type hint * temporarily refactoring mmcv's voxelize and dynamic in mmdet3d for data_preprocesser * _forward * del checkpoints * add if tp * add predict * fix vfe init bug & fix UT * add grid_size & move voxelization code * fix import bug * keep radian to follow origin * add doc string * fix type hint * add minkunet voxelization and loss function * fix data * init train * fix sparsetensor typehint * rename dir * fix data config * fix data config * fix batch_size & replace dynamic_scatter * fix conflicts 2 * fix conflicts on s_70 * Alignment of the original implementation * rename config * add worker_init_fn_hook * remove test_config & worker hook * add UT * fix polarmix UT * add seed for cr0p5 * format * rename SemanticKittiDataset * add platte & fix visual bug * add platte & fix data info bug * fix ut * fix semantic_kitti ut * fix docstring * fix config name * rename layer * fix doc string * fix review * remove filter data * fix coors typo * fix ut * pred in segmentor * fix get voxel seg * resolve comments
1 parent be2029d commit ee6cc04

File tree

15 files changed

+648
-3
lines changed

15 files changed

+648
-3
lines changed

configs/_base_/models/minkunet.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
model = dict(
2+
type='MinkUNet',
3+
data_preprocessor=dict(
4+
type='Det3DDataPreprocessor',
5+
voxel=True,
6+
voxel_type='minkunet',
7+
voxel_layer=dict(
8+
max_num_points=-1,
9+
point_cloud_range=[-100, -100, -20, 100, 100, 20],
10+
voxel_size=[0.05, 0.05, 0.05],
11+
max_voxels=(-1, -1)),
12+
),
13+
backbone=dict(
14+
type='MinkUNetBackbone',
15+
in_channels=4,
16+
base_channels=32,
17+
encoder_channels=[32, 64, 128, 256],
18+
decoder_channels=[256, 128, 96, 96],
19+
num_stages=4,
20+
init_cfg=None),
21+
decode_head=dict(
22+
type='MinkUNetHead',
23+
channels=96,
24+
num_classes=19,
25+
dropout_ratio=0,
26+
loss_decode=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
27+
ignore_index=19),
28+
train_cfg=dict(),
29+
test_cfg=dict())
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']
2+
3+
model = dict(
4+
backbone=dict(
5+
base_channels=16,
6+
encoder_channels=[16, 32, 64, 128],
7+
decoder_channels=[128, 64, 48, 48]),
8+
decode_head=dict(channels=48))
9+
10+
# NOTE: With the TorchSparse backend, model performance is sensitive to the
11+
# random seed; if no seed is specified, results can vary between runs
12+
# (± 1.5 mIoU).
13+
randomness = dict(seed=1588147245)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']
2+
3+
model = dict(
4+
backbone=dict(
5+
base_channels=20,
6+
encoder_channels=[20, 40, 81, 163],
7+
decoder_channels=[163, 81, 61, 61]),
8+
decode_head=dict(channels=61))
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
_base_ = [
2+
'../_base_/datasets/semantickitti.py', '../_base_/models/minkunet.py',
3+
'../_base_/default_runtime.py'
4+
]
5+
6+
train_pipeline = [
7+
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
8+
dict(
9+
type='LoadAnnotations3D',
10+
with_bbox_3d=False,
11+
with_label_3d=False,
12+
with_seg_3d=True,
13+
seg_3d_dtype='np.int32',
14+
seg_offset=2**16,
15+
dataset_type='semantickitti'),
16+
dict(type='PointSegClassMapping'),
17+
dict(
18+
type='GlobalRotScaleTrans',
19+
rot_range=[0., 6.28318531],
20+
scale_ratio_range=[0.95, 1.05],
21+
translation_std=[0, 0, 0],
22+
),
23+
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
24+
]
25+
26+
train_dataloader = dict(
27+
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
28+
29+
lr = 0.24
30+
optim_wrapper = dict(
31+
type='AmpOptimWrapper',
32+
loss_scale='dynamic',
33+
optimizer=dict(
34+
type='SGD', lr=lr, weight_decay=0.0001, momentum=0.9, nesterov=True))
35+
36+
param_scheduler = [
37+
dict(
38+
type='LinearLR', start_factor=0.008, by_epoch=False, begin=0, end=125),
39+
dict(
40+
type='CosineAnnealingLR',
41+
begin=0,
42+
T_max=15,
43+
by_epoch=True,
44+
eta_min=1e-5,
45+
convert_to_iter_based=True)
46+
]
47+
48+
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=15, val_interval=1)
49+
val_cfg = dict(type='ValLoop')
50+
test_cfg = dict(type='TestLoop')
51+
52+
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1))
53+
randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
54+
env_cfg = dict(cudnn_benchmark=True)

mmdet3d/models/backbones/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .dgcnn import DGCNNBackbone
66
from .dla import DLANet
77
from .mink_resnet import MinkResNet
8+
from .minkunet_backbone import MinkUNetBackbone
89
from .multi_backbone import MultiBackbone
910
from .nostem_regnet import NoStemRegNet
1011
from .pointnet2_sa_msg import PointNet2SAMSG
@@ -14,5 +15,6 @@
1415
__all__ = [
1516
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
1617
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
17-
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv'
18+
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv',
19+
'MinkUNetBackbone'
1820
]
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import List
3+
4+
from mmengine.model import BaseModule
5+
from mmengine.registry import MODELS
6+
from torch import Tensor, nn
7+
8+
from mmdet3d.models.layers import (TorchSparseConvModule,
9+
TorchSparseResidualBlock)
10+
from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
11+
from mmdet3d.utils import OptMultiConfig
12+
13+
if IS_TORCHSPARSE_AVAILABLE:
14+
import torchsparse
15+
from torchsparse.tensor import SparseTensor
16+
else:
17+
SparseTensor = None
18+
19+
20+
@MODELS.register_module()
class MinkUNetBackbone(BaseModule):
    r"""MinkUNet backbone with TorchSparse backend.

    Refer to `implementation code <https://github.com/mit-han-lab/spvnas>`_.

    The network is a sparse U-Net: an input stem of two 3x3x3 sparse
    convolutions, ``num_stages`` encoder stages (a stride-2 convolution
    followed by two residual blocks) and ``num_stages`` decoder stages (a
    stride-2 transposed convolution, concatenation with the matching
    encoder feature, then two residual blocks).

    Args:
        in_channels (int): Number of input voxel feature channels.
            Defaults to 4.
        base_channels (int): The input channels for first encoder layer.
            Defaults to 32.
        encoder_channels (List[int]): Convolutional channels of each encode
            layer. Defaults to [32, 64, 128, 256].
        decoder_channels (List[int]): Convolutional channels of each decode
            layer. Defaults to [256, 128, 96, 96].
        num_stages (int): Number of stages in encoder and decoder.
            Must equal the length of both channel lists. Defaults to 4.
        init_cfg (dict or :obj:`ConfigDict` or List[dict or :obj:`ConfigDict`]
            , optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels: int = 4,
                 base_channels: int = 32,
                 encoder_channels: List[int] = [32, 64, 128, 256],
                 decoder_channels: List[int] = [256, 128, 96, 96],
                 num_stages: int = 4,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)
        assert num_stages == len(encoder_channels) == len(decoder_channels)
        self.num_stages = num_stages
        self.conv_input = nn.Sequential(
            TorchSparseConvModule(in_channels, base_channels, kernel_size=3),
            TorchSparseConvModule(base_channels, base_channels, kernel_size=3))
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # Build extended *local* channel lists instead of inserting into the
        # arguments in place. In-place insertion would mutate the shared
        # default lists (breaking the length assertion on the next default
        # construction) as well as any list owned by the caller / config.
        encoder_channels = [base_channels] + list(encoder_channels)
        decoder_channels = [encoder_channels[-1]] + list(decoder_channels)

        for i in range(num_stages):
            # Encoder stage i: downsample by 2, then change the channel width
            # from encoder_channels[i] to encoder_channels[i + 1].
            self.encoder.append(
                nn.Sequential(
                    TorchSparseConvModule(
                        encoder_channels[i],
                        encoder_channels[i],
                        kernel_size=2,
                        stride=2),
                    TorchSparseResidualBlock(
                        encoder_channels[i],
                        encoder_channels[i + 1],
                        kernel_size=3),
                    TorchSparseResidualBlock(
                        encoder_channels[i + 1],
                        encoder_channels[i + 1],
                        kernel_size=3)))

            # Decoder stage i is kept as a ModuleList (not Sequential)
            # because the skip connection is concatenated between the
            # upsampling convolution [0] and the residual blocks [1].
            self.decoder.append(
                nn.ModuleList([
                    TorchSparseConvModule(
                        decoder_channels[i],
                        decoder_channels[i + 1],
                        kernel_size=2,
                        stride=2,
                        transposed=True),
                    nn.Sequential(
                        TorchSparseResidualBlock(
                            # Upsampled channels plus the channels of the
                            # lateral feature taken from the mirrored
                            # encoder level.
                            decoder_channels[i + 1] + encoder_channels[-2 - i],
                            decoder_channels[i + 1],
                            kernel_size=3),
                        TorchSparseResidualBlock(
                            decoder_channels[i + 1],
                            decoder_channels[i + 1],
                            kernel_size=3))
                ]))

    def forward(self, voxel_features: Tensor, coors: Tensor) -> SparseTensor:
        """Forward function.

        Args:
            voxel_features (Tensor): Voxel features in shape (N, C).
            coors (Tensor): Coordinates in shape (N, 4),
                the columns in the order of (x_idx, y_idx, z_idx, batch_idx).

        Returns:
            SparseTensor: Backbone features, i.e. the output of the last
            decoder stage.
        """
        x = torchsparse.SparseTensor(voxel_features, coors)
        x = self.conv_input(x)

        # Collect the stem output and every encoder output as skip features.
        laterals = [x]
        for encoder_layer in self.encoder:
            x = encoder_layer(x)
            laterals.append(x)
        # Drop the deepest feature (it is the decoder input itself) and
        # reverse so that laterals[i] pairs with decoder stage i.
        laterals = laterals[:-1][::-1]

        decoder_outs = []
        for i, decoder_layer in enumerate(self.decoder):
            x = decoder_layer[0](x)
            x = torchsparse.cat((x, laterals[i]))
            x = decoder_layer[1](x)
            decoder_outs.append(x)

        return decoder_outs[-1]

mmdet3d/models/data_preprocessors/data_preprocessor.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,33 @@ def voxelize(self, points: List[torch.Tensor],
415415
coors.append(res_coors)
416416
voxels = torch.cat(voxels, dim=0)
417417
coors = torch.cat(coors, dim=0)
418+
elif self.voxel_type == 'minkunet':
419+
voxels, coors = [], []
420+
voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size)
421+
for i, (res, data_sample) in enumerate(zip(points, data_samples)):
422+
res_coors = torch.round(res[:, :3] / voxel_size).int()
423+
res_coors -= res_coors.min(0)[0]
424+
425+
res_coors_numpy = res_coors.cpu().numpy()
426+
inds, voxel2point_map = self.sparse_quantize(
427+
res_coors_numpy, return_index=True, return_inverse=True)
428+
voxel2point_map = torch.from_numpy(voxel2point_map).cuda()
429+
if self.training:
430+
if len(inds) > 80000:
431+
inds = np.random.choice(inds, 80000, replace=False)
432+
inds = torch.from_numpy(inds).cuda()
433+
data_sample.gt_pts_seg.voxel_semantic_mask \
434+
= data_sample.gt_pts_seg.pts_semantic_mask[inds]
435+
res_voxel_coors = res_coors[inds]
436+
res_voxels = res[inds]
437+
res_voxel_coors = F.pad(
438+
res_voxel_coors, (0, 1), mode='constant', value=i)
439+
data_sample.voxel2point_map = voxel2point_map.long()
440+
voxels.append(res_voxels)
441+
coors.append(res_voxel_coors)
442+
voxels = torch.cat(voxels, dim=0)
443+
coors = torch.cat(coors, dim=0)
444+
418445
else:
419446
raise ValueError(f'Invalid voxelization type {self.voxel_type}')
420447

@@ -445,3 +472,53 @@ def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
445472
_, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
446473
res_coors, 'mean', True)
447474
data_sample.gt_pts_seg.point2voxel_map = point2voxel_map
475+
476+
def ravel_hash(self, x: np.ndarray) -> np.ndarray:
    """Collapse voxel coordinates into one hash value per row.

    Each row is shifted so that the per-column minimum becomes zero and is
    then encoded as a single mixed-radix ``uint64`` number whose radices
    are the per-column extents. Rows with equal coordinates therefore get
    equal hashes, which is what ``np.unique()`` needs.

    Args:
        x (np.ndarray): The voxel coordinates of points, Nx3.

    Returns:
        np.ndarray: Voxels coordinates hash, shape (N,), dtype uint64.
    """
    assert x.ndim == 2, x.shape

    # Shift to non-negative values before the unsigned cast.
    shifted = (x - np.min(x, axis=0)).astype(np.uint64, copy=False)
    extents = np.max(shifted, axis=0).astype(np.uint64) + 1

    hashed = np.zeros(shifted.shape[0], dtype=np.uint64)
    num_dims = shifted.shape[1]
    # Horner-style accumulation over all but the last column.
    for axis in range(1, num_dims):
        hashed += shifted[:, axis - 1]
        hashed *= extents[axis]
    hashed += shifted[:, -1]
    return hashed
497+
498+
def sparse_quantize(self,
                    coords: np.ndarray,
                    return_index: bool = False,
                    return_inverse: bool = False) -> List[np.ndarray]:
    """Sparse Quantization for voxel coordinates used in Minkunet.

    Points that share the same voxel coordinates are grouped by hashing
    every row with :meth:`ravel_hash` and de-duplicating the hashes with
    ``np.unique()``.

    Args:
        coords (np.ndarray): The voxel coordinates of points, Nx3.
        return_index (bool): Whether to return the indices of the
            unique coords, shape (M,).
        return_inverse (bool): Whether to return the indices of the
            original coords shape (N,).

    Returns:
        List[np.ndarray]: The requested arrays, in the order
        ``[indices, inverse_indices]``. Only the arrays whose flag is True
        are included, so the list is empty when both flags are False.
    """
    # Both maps are always computed; the flags only select what is returned.
    _, indices, inverse_indices = np.unique(
        self.ravel_hash(coords), return_index=True, return_inverse=True)

    outputs = []
    if return_index:
        outputs.append(indices)
    if return_inverse:
        outputs.append(inverse_indices)
    return outputs
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .cylinder3d_head import Cylinder3DHead
33
from .dgcnn_head import DGCNNHead
4+
from .minkunet_head import MinkUNetHead
45
from .paconv_head import PAConvHead
56
from .pointnet2_head import PointNet2Head
67

7-
__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead']
8+
__all__ = [
9+
'PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead',
10+
'MinkUNetHead'
11+
]

0 commit comments

Comments
 (0)