Skip to content

Commit dfcf542

Browse files
authored
[Feature] Cylinder3d segmentor (#2344)
* update * add cylinder3d_backbone * add test segmentor * add cfg * add test backbone * rename test cylinder3d backbone * midway * update, pass validation * fix test * update cfg
1 parent afa4479 commit dfcf542

File tree

11 files changed

+790
-12
lines changed

11 files changed

+790
-12
lines changed

configs/_base_/models/cylinder3d.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
grid_shape = [480, 360, 32]
2+
model = dict(
3+
type='Cylinder3D',
4+
data_preprocessor=dict(
5+
type='Det3DDataPreprocessor',
6+
voxel=True,
7+
voxel_type='cylindrical',
8+
voxel_layer=dict(
9+
grid_shape=grid_shape,
10+
point_cloud_range=[0, -3.14159265359, -4, 50, 3.14159265359, 2],
11+
max_num_points=-1,
12+
max_voxels=-1,
13+
),
14+
),
15+
voxel_encoder=dict(
16+
type='SegVFE',
17+
feat_channels=[64, 128, 256, 256],
18+
in_channels=6,
19+
with_voxel_center=True,
20+
feat_compression=16,
21+
return_point_feats=False),
22+
backbone=dict(
23+
type='Asymm3DSpconv',
24+
grid_size=grid_shape,
25+
input_channels=16,
26+
base_channels=32,
27+
norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.1)),
28+
decode_head=dict(
29+
type='Cylinder3DHead',
30+
channels=128,
31+
num_classes=20,
32+
loss_ce=dict(
33+
type='mmdet.CrossEntropyLoss',
34+
use_sigmoid=False,
35+
class_weight=None,
36+
loss_weight=1.0),
37+
loss_lovasz=dict(type='LovaszLoss', loss_weight=1.0, reduction='none'),
38+
),
39+
train_cfg=None,
40+
test_cfg=dict(mode='whole'),
41+
)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
_base_ = [
2+
'../_base_/datasets/semantickitti.py', '../_base_/models/cylinder3d.py',
3+
'../_base_/default_runtime.py'
4+
]
5+
6+
# optimizer
7+
# This schedule is used here with the SemanticKITTI dataset (see `_base_`)
8+
lr = 0.001
9+
optim_wrapper = dict(
10+
type='OptimWrapper',
11+
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01))
12+
13+
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=36, val_interval=1)
14+
val_cfg = dict(type='ValLoop')
15+
test_cfg = dict(type='TestLoop')
16+
17+
# learning rate
18+
param_scheduler = [
19+
dict(
20+
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
21+
end=1000),
22+
dict(
23+
type='MultiStepLR',
24+
begin=0,
25+
end=36,
26+
by_epoch=True,
27+
milestones=[30],
28+
gamma=0.1)
29+
]
30+
31+
# Default setting for scaling LR automatically
32+
# - `enable` means enable scaling LR automatically
33+
# or not by default.
34+
# - `base_batch_size` = (8 GPUs) x (4 samples per GPU).
35+
# auto_scale_lr = dict(enable=False, base_batch_size=32)
36+
37+
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5))

mmdet3d/datasets/seg3d_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,8 @@ def parse_data_info(self, info: dict) -> dict:
255255
osp.join(
256256
self.data_prefix.get('pts', ''),
257257
info['lidar_points']['lidar_path'])
258-
259-
info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
258+
if 'num_pts_feats' in info['lidar_points']:
259+
info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
260260
info['lidar_path'] = info['lidar_points']['lidar_path']
261261

262262
if self.modality['use_camera']:

mmdet3d/models/backbones/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
33

4+
from .cylinder3d import Asymm3DSpconv
45
from .dgcnn import DGCNNBackbone
56
from .dla import DLANet
67
from .mink_resnet import MinkResNet
@@ -13,5 +14,5 @@
1314
__all__ = [
1415
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
1516
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
16-
'MultiBackbone', 'DLANet', 'MinkResNet'
17+
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv'
1718
]
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
r"""Modified from Cylinder3D.
3+
4+
Please refer to `Cylinder3D github page
5+
<https://github.com/xinge008/Cylinder3D>`_ for details
6+
"""
7+
8+
from typing import List
9+
10+
import numpy as np
11+
import torch
12+
from mmcv.ops import SparseConvTensor
13+
from mmengine.model import BaseModule
14+
15+
from mmdet3d.models.layers.sparse_block import (AsymmeDownBlock, AsymmeUpBlock,
16+
AsymmResBlock, DDCMBlock)
17+
from mmdet3d.registry import MODELS
18+
from mmdet3d.utils import ConfigType
19+
20+
21+
@MODELS.register_module()
class Asymm3DSpconv(BaseModule):
    """Asymmetrical 3D convolution networks.

    The network is built from a context block, ``backbone_depth`` down
    blocks, the same number of up blocks that consume the down blocks' skip
    outputs in reverse order, and a final DDCM block whose features are
    concatenated with the features of the last up block.

    Args:
        grid_size (List[int]): Size of voxel grids. It is converted with
            ``np.array`` and used as the spatial shape of the input sparse
            tensor.
        input_channels (int): Input channels of the block.
        base_channels (int): Initial size of feature channels before
            feeding into Encoder-Decoder structure. Defaults to 16.
        backbone_depth (int): The depth of backbone. The backbone contains
            downblocks and upblocks with the number of backbone_depth.
            Defaults to 4.
        height_pooing (List[bool]): List indicating which downblocks perform
            height pooling. Must provide at least ``backbone_depth`` entries.
            Defaults to [True, True, False, False].
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01).
        init_cfg (dict, optional): Initialization config.
            Defaults to None.

    Raises:
        ValueError: If ``height_pooing`` has fewer entries than
            ``backbone_depth``.
    """

    def __init__(self,
                 grid_size: List[int],
                 input_channels: int,
                 base_channels: int = 16,
                 backbone_depth: int = 4,
                 height_pooing: List[bool] = [True, True, False, False],
                 norm_cfg: ConfigType = dict(
                     type='BN1d', eps=1e-3, momentum=0.01),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Fail early with a clear message instead of an IndexError raised
        # halfway through building the down blocks.
        if len(height_pooing) < backbone_depth:
            raise ValueError(
                'height_pooing must provide one flag per down block: got '
                f'{len(height_pooing)} entries for backbone_depth='
                f'{backbone_depth}')

        self.grid_size = grid_size
        self.backbone_depth = backbone_depth
        self.down_context = AsymmResBlock(
            input_channels, base_channels, indice_key='pre', norm_cfg=norm_cfg)

        self.down_block_list = torch.nn.ModuleList()
        self.up_block_list = torch.nn.ModuleList()
        deepest = self.backbone_depth - 1
        for i in range(self.backbone_depth):
            # Output width of down block i; up block i produces the same
            # width.
            level_channels = 2**(i + 1) * base_channels
            self.down_block_list.append(
                AsymmeDownBlock(
                    2**i * base_channels,
                    level_channels,
                    height_pooling=height_pooing[i],
                    indice_key=f'down{i}',
                    norm_cfg=norm_cfg))

            # The deepest up block is fed by the last down block (same
            # width); every other up block is fed by the up block one level
            # deeper, which is twice as wide.
            if i == deepest:
                up_in_channels = level_channels
            else:
                up_in_channels = 2**(i + 2) * base_channels
            self.up_block_list.append(
                AsymmeUpBlock(
                    up_in_channels,
                    level_channels,
                    up_key=f'down{i}',
                    indice_key=f'up{deepest - i}',
                    norm_cfg=norm_cfg))

        self.ddcm = DDCMBlock(
            2 * base_channels,
            2 * base_channels,
            indice_key='ddcm',
            norm_cfg=norm_cfg)

    def forward(self, voxel_features: torch.Tensor, coors: torch.Tensor,
                batch_size: int) -> SparseConvTensor:
        """Forward pass.

        Args:
            voxel_features (torch.Tensor): Features of the voxels.
            coors (torch.Tensor): Voxel coordinates; cast to int before
                building the sparse tensor.
                NOTE(review): presumably one row per voxel with the batch
                index first - confirm against the voxel encoder.
            batch_size (int): Batch size passed to the sparse tensor.

        Returns:
            SparseConvTensor: Output of the DDCM block whose ``features``
            are concatenated along dim 1 with the features of the last up
            block.
        """
        coors = coors.int()
        ret = SparseConvTensor(voxel_features, coors, np.array(self.grid_size),
                               batch_size)
        ret = self.down_context(ret)

        # Encoder: keep each down block's skip output for the decoder.
        down_skip_list = []
        down_pool = ret
        for i in range(self.backbone_depth):
            down_pool, down_skip = self.down_block_list[i](down_pool)
            down_skip_list.append(down_skip)

        # Decoder: walk back from the deepest level to the shallowest.
        up = down_pool
        for i in reversed(range(self.backbone_depth)):
            up = self.up_block_list[i](up, down_skip_list[i])

        ddcm = self.ddcm(up)
        ddcm.features = torch.cat((ddcm.features, up.features), 1)

        return ddcm

mmdet3d/models/decode_heads/cylinder3d_head.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class Cylinder3DHead(Base3DDecodeHead):
3939
conv_seg_kernel_size (int): The kernel size used in conv_seg.
4040
Defaults to 3.
4141
ignore_index (int): The label index to be ignored. When using masked
42-
BCE loss, ignore_index should be set to None. Defaults to 0.
42+
BCE loss, ignore_index should be set to None. Defaults to 19.
4343
init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`],
4444
optional): Initialization config dict. Defaults to None.
4545
"""
@@ -59,7 +59,7 @@ def __init__(self,
5959
loss_lovasz: ConfigType = dict(
6060
type='LovaszLoss', loss_weight=1.0),
6161
conv_seg_kernel_size: int = 3,
62-
ignore_index: int = 0,
62+
ignore_index: int = 19,
6363
init_cfg: OptMultiConfig = None) -> None:
6464
super(Cylinder3DHead, self).__init__(
6565
channels=channels,
@@ -116,8 +116,6 @@ def loss_by_feat(self, seg_logit: SparseConvTensor,
116116
loss = dict()
117117
loss['loss_ce'] = self.loss_ce(
118118
seg_logit_feat, seg_label, ignore_index=self.ignore_index)
119-
seg_logit_feat = seg_logit_feat.permute(1, 0)[None, :, :,
120-
None] # pseudo BCHW
121119
loss['loss_lovasz'] = self.loss_lovasz(
122120
seg_logit_feat, seg_label, ignore_index=self.ignore_index)
123121

0 commit comments

Comments
 (0)