Skip to content

Commit 762e3b5

Browse files
[Feature] Support DSVT training (#2738)
Co-authored-by: JingweiZhang12 <[email protected]> Co-authored-by: sjh <sunjiahao1999>
1 parent 5b88c7b commit 762e3b5

File tree

14 files changed

+875
-86
lines changed

14 files changed

+875
-86
lines changed

mmdet3d/models/dense_heads/centerpoint_head.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def forward(self, x):
101101
Returns:
102102
dict[str: torch.Tensor]: contains the following keys:
103103
104-
-reg torch.Tensor): 2D regression value with the
104+
-reg (torch.Tensor): 2D regression value with the
105105
shape of [B, 2, H, W].
106106
-height (torch.Tensor): Height value with the
107107
shape of [B, 1, H, W].
@@ -217,7 +217,7 @@ def forward(self, x):
217217
Returns:
218218
dict[str: torch.Tensor]: contains the following keys:
219219
220-
-reg torch.Tensor): 2D regression value with the
220+
-reg (torch.Tensor): 2D regression value with the
221221
shape of [B, 2, H, W].
222222
-height (torch.Tensor): Height value with the
223223
shape of [B, 1, H, W].

mmdet3d/models/necks/second_fpn.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ class SECONDFPN(BaseModule):
2121
upsample_cfg (dict): Config dict of upsample layers.
2222
conv_cfg (dict): Config dict of conv layers.
2323
use_conv_for_no_stride (bool): Whether to use conv when stride is 1.
24+
init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`],
25+
optional): Initialization config dict. Defaults to
26+
[dict(type='Kaiming', layer='ConvTranspose2d'),
27+
dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)].
2428
"""
2529

2630
def __init__(self,
@@ -31,7 +35,13 @@ def __init__(self,
3135
upsample_cfg=dict(type='deconv', bias=False),
3236
conv_cfg=dict(type='Conv2d', bias=False),
3337
use_conv_for_no_stride=False,
34-
init_cfg=None):
38+
init_cfg=[
39+
dict(type='Kaiming', layer='ConvTranspose2d'),
40+
dict(
41+
type='Constant',
42+
layer='NaiveSyncBatchNorm2d',
43+
val=1.0)
44+
]):
3545
# if for GroupNorm,
3646
# cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
3747
super(SECONDFPN, self).__init__(init_cfg=init_cfg)
@@ -64,12 +74,6 @@ def __init__(self,
6474
deblocks.append(deblock)
6575
self.deblocks = nn.ModuleList(deblocks)
6676

67-
if init_cfg is None:
68-
self.init_cfg = [
69-
dict(type='Kaiming', layer='ConvTranspose2d'),
70-
dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)
71-
]
72-
7377
def forward(self, x):
7478
"""Forward function.
7579

mmdet3d/structures/bbox_3d/base_box3d.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,13 @@ def in_range_3d(
275275
Tensor: A binary vector indicating whether each point is inside the
276276
reference range.
277277
"""
278-
in_range_flags = ((self.tensor[:, 0] > box_range[0])
279-
& (self.tensor[:, 1] > box_range[1])
280-
& (self.tensor[:, 2] > box_range[2])
281-
& (self.tensor[:, 0] < box_range[3])
282-
& (self.tensor[:, 1] < box_range[4])
283-
& (self.tensor[:, 2] < box_range[5]))
278+
gravity_center = self.gravity_center
279+
in_range_flags = ((gravity_center[:, 0] > box_range[0])
280+
& (gravity_center[:, 1] > box_range[1])
281+
& (gravity_center[:, 2] > box_range[2])
282+
& (gravity_center[:, 0] < box_range[3])
283+
& (gravity_center[:, 1] < box_range[4])
284+
& (gravity_center[:, 2] < box_range[5]))
284285
return in_range_flags
285286

286287
@abstractmethod

projects/DSVT/README.md

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,25 @@ python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-
5757

5858
### Training commands
5959

60-
The support of training DSVT is on the way.
60+
In MMDetection3D's root directory, run the following command to train the model:
61+
62+
```bash
63+
tools/dist_train.sh projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py 8 --sync_bn torch
64+
```
6165

6266
## Results and models
6367

6468
### Waymo
6569

66-
| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
67-
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :------: |
68-
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) || × | | | 75.2 | 72.2 | 68.9 | 66.1 | |
70+
| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
71+
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :----: | :-----: | :----: | :---------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------: |
72+
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | 75.5 | 72.4 | 69.2 | 66.3 | [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/dsvt/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class_20230917_102130.log) |
73+
74+
**Note**:
75+
76+
- `ResSECOND` denotes that the base block in SECOND has residual layers.
6977

70-
**Note** that `ResSECOND` denotes the base block in SECOND has residual layers.
78+
- Regrettably, we are unable to provide the pre-trained model weights due to the [Waymo Dataset License Agreement](https://waymo.com/open/terms/), so we only provide the training logs as shown above.
7179

7280
## Citation
7381

projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py

Lines changed: 103 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -88,25 +88,28 @@
8888
loss_cls=dict(
8989
type='mmdet.GaussianFocalLoss', reduction='mean', loss_weight=1.0),
9090
loss_bbox=dict(type='mmdet.L1Loss', reduction='mean', loss_weight=2.0),
91+
loss_iou=dict(type='mmdet.L1Loss', reduction='sum', loss_weight=1.0),
92+
loss_reg_iou=dict(
93+
type='mmdet3d.DIoU3DLoss', reduction='mean', loss_weight=2.0),
9194
norm_bbox=True),
9295
# model training and testing settings
9396
train_cfg=dict(
94-
pts=dict(
95-
grid_size=grid_size,
96-
voxel_size=voxel_size,
97-
out_size_factor=4,
98-
dense_reg=1,
99-
gaussian_overlap=0.1,
100-
max_objs=500,
101-
min_radius=2,
102-
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])),
97+
grid_size=grid_size,
98+
voxel_size=voxel_size,
99+
point_cloud_range=point_cloud_range,
100+
out_size_factor=1,
101+
dense_reg=1,
102+
gaussian_overlap=0.1,
103+
max_objs=500,
104+
min_radius=2,
105+
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
103106
test_cfg=dict(
104107
max_per_img=500,
105108
max_pool_nms=False,
106109
min_radius=[4, 12, 10, 1, 0.85, 0.175],
107110
iou_rectifier=[[0.68, 0.71, 0.65]],
108111
pc_range=[-80, -80],
109-
out_size_factor=4,
112+
out_size_factor=1,
110113
voxel_size=voxel_size[:2],
111114
nms_type='rotate',
112115
multi_class_nms=True,
@@ -128,6 +131,8 @@
128131
coord_type='LIDAR',
129132
load_dim=6,
130133
use_dim=[0, 1, 2, 3, 4],
134+
norm_intensity=True,
135+
norm_elongation=True,
131136
backend_args=backend_args),
132137
backend_args=backend_args)
133138

@@ -138,25 +143,22 @@
138143
load_dim=6,
139144
use_dim=5,
140145
norm_intensity=True,
146+
norm_elongation=True,
141147
backend_args=backend_args),
142-
# Add this if using `MultiFrameDeformableDecoderRPN`
143-
# dict(
144-
# type='LoadPointsFromMultiSweeps',
145-
# sweeps_num=9,
146-
# load_dim=6,
147-
# use_dim=[0, 1, 2, 3, 4],
148-
# pad_empty_sweeps=True,
149-
# remove_close=True),
150148
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
151149
dict(type='ObjectSample', db_sampler=db_sampler),
150+
dict(
151+
type='RandomFlip3D',
152+
sync_2d=False,
153+
flip_ratio_bev_horizontal=0.5,
154+
flip_ratio_bev_vertical=0.5),
152155
dict(
153156
type='GlobalRotScaleTrans',
154157
rot_range=[-0.78539816, 0.78539816],
155158
scale_ratio_range=[0.95, 1.05],
156-
translation_std=[0.5, 0.5, 0]),
157-
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
158-
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
159-
dict(type='ObjectNameFilter', classes=class_names),
159+
translation_std=[0.5, 0.5, 0.5]),
160+
dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
161+
dict(type='ObjectRangeFilter3D', point_cloud_range=point_cloud_range),
160162
dict(type='PointShuffle'),
161163
dict(
162164
type='Pack3DDetInputs',
@@ -172,25 +174,34 @@
172174
norm_intensity=True,
173175
norm_elongation=True,
174176
backend_args=backend_args),
177+
dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
175178
dict(
176-
type='MultiScaleFlipAug3D',
177-
img_scale=(1333, 800),
178-
pts_scale_ratio=1,
179-
flip=False,
180-
transforms=[
181-
dict(
182-
type='GlobalRotScaleTrans',
183-
rot_range=[0, 0],
184-
scale_ratio_range=[1., 1.],
185-
translation_std=[0, 0, 0]),
186-
dict(type='RandomFlip3D'),
187-
dict(
188-
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
189-
]),
190-
dict(type='Pack3DDetInputs', keys=['points'])
179+
type='Pack3DDetInputs',
180+
keys=['points'],
181+
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
191182
]
192183

193184
dataset_type = 'WaymoDataset'
185+
train_dataloader = dict(
186+
batch_size=1,
187+
num_workers=4,
188+
persistent_workers=True,
189+
sampler=dict(type='DefaultSampler', shuffle=True),
190+
dataset=dict(
191+
type=dataset_type,
192+
data_root=data_root,
193+
ann_file='waymo_infos_train.pkl',
194+
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
195+
pipeline=train_pipeline,
196+
modality=input_modality,
197+
test_mode=False,
198+
metainfo=metainfo,
199+
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
200+
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
201+
box_type_3d='LiDAR',
202+
# load one frame every five frames
203+
load_interval=5,
204+
backend_args=backend_args))
194205
val_dataloader = dict(
195206
batch_size=4,
196207
num_workers=4,
@@ -212,18 +223,59 @@
212223

213224
val_evaluator = dict(
214225
type='WaymoMetric',
215-
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
216226
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
217-
data_root='./data/waymo/waymo_format',
218-
backend_args=backend_args,
219-
convert_kitti_format=False,
220-
idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl')
227+
result_prefix='./dsvt_pred')
221228
test_evaluator = val_evaluator
222229

223230
vis_backends = [dict(type='LocalVisBackend')]
224231
visualizer = dict(
225232
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
226233

234+
# schedules
235+
lr = 1e-5
236+
optim_wrapper = dict(
237+
type='OptimWrapper',
238+
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.05, betas=(0.9, 0.99)),
239+
clip_grad=dict(max_norm=10, norm_type=2))
240+
param_scheduler = [
241+
dict(
242+
type='CosineAnnealingLR',
243+
T_max=1.2,
244+
eta_min=lr * 100,
245+
begin=0,
246+
end=1.2,
247+
by_epoch=True,
248+
convert_to_iter_based=True),
249+
dict(
250+
type='CosineAnnealingLR',
251+
T_max=10.8,
252+
eta_min=lr * 1e-4,
253+
begin=1.2,
254+
end=12,
255+
by_epoch=True,
256+
convert_to_iter_based=True),
257+
# momentum scheduler
258+
dict(
259+
type='CosineAnnealingMomentum',
260+
T_max=1.2,
261+
eta_min=0.85,
262+
begin=0,
263+
end=1.2,
264+
by_epoch=True,
265+
convert_to_iter_based=True),
266+
dict(
267+
type='CosineAnnealingMomentum',
268+
T_max=10.8,
269+
eta_min=0.95,
270+
begin=1.2,
271+
end=12,
272+
by_epoch=True,
273+
convert_to_iter_based=True)
274+
]
275+
276+
# runtime settings
277+
train_cfg = dict(by_epoch=True, max_epochs=12, val_interval=1)
278+
227279
# runtime settings
228280
val_cfg = dict()
229281
test_cfg = dict()
@@ -236,4 +288,12 @@
236288

237289
default_hooks = dict(
238290
logger=dict(type='LoggerHook', interval=50),
239-
checkpoint=dict(type='CheckpointHook', interval=5))
291+
checkpoint=dict(type='CheckpointHook', interval=1))
292+
custom_hooks = [
293+
dict(
294+
type='DisableAugHook',
295+
disable_after_epoch=11,
296+
disable_aug_list=[
297+
'GlobalRotScaleTrans', 'RandomFlip3D', 'ObjectSample'
298+
])
299+
]

projects/DSVT/dsvt/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
from .disable_aug_hook import DisableAugHook
12
from .dsvt import DSVT
23
from .dsvt_head import DSVTCenterHead
34
from .dsvt_transformer import DSVTMiddleEncoder
45
from .dynamic_pillar_vfe import DynamicPillarVFE3D
56
from .map2bev import PointPillarsScatter3D
67
from .res_second import ResSECOND
8+
from .transforms_3d import ObjectRangeFilter3D, PointsRangeFilter3D
79
from .utils import DSVTBBoxCoder
810

911
__all__ = [
1012
'DSVTCenterHead', 'DSVT', 'DSVTMiddleEncoder', 'DynamicPillarVFE3D',
11-
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder'
13+
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder',
14+
'ObjectRangeFilter3D', 'PointsRangeFilter3D', 'DisableAugHook'
1215
]

0 commit comments

Comments
 (0)