Skip to content

Commit 8fb2cf6

Browse files
[Feature] Support inference of DSVT in projects (#2606)
* support inference * align inference precision * add readme * polish docs * polish docs
1 parent 456b740 commit 8fb2cf6

18 files changed

+2549
-1
lines changed

mmdet3d/datasets/transforms/loading.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,8 @@ class LoadPointsFromFile(BaseTransform):
579579
use_color (bool): Whether to use color features. Defaults to False.
580580
norm_intensity (bool): Whether to normalize the intensity. Defaults to
581581
False.
582+
norm_elongation (bool): Whether to normalize the elongation. This is
583+
usually used in Waymo dataset. Defaults to False.
582584
backend_args (dict, optional): Arguments to instantiate the
583585
corresponding backend. Defaults to None.
584586
"""
@@ -590,6 +592,7 @@ def __init__(self,
590592
shift_height: bool = False,
591593
use_color: bool = False,
592594
norm_intensity: bool = False,
595+
norm_elongation: bool = False,
593596
backend_args: Optional[dict] = None) -> None:
594597
self.shift_height = shift_height
595598
self.use_color = use_color
@@ -603,6 +606,7 @@ def __init__(self,
603606
self.load_dim = load_dim
604607
self.use_dim = use_dim
605608
self.norm_intensity = norm_intensity
609+
self.norm_elongation = norm_elongation
606610
self.backend_args = backend_args
607611

608612
def _load_points(self, pts_filename: str) -> np.ndarray:
@@ -646,6 +650,10 @@ def transform(self, results: dict) -> dict:
646650
assert len(self.use_dim) >= 4, \
647651
f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}' # noqa: E501
648652
points[:, 3] = np.tanh(points[:, 3])
653+
if self.norm_elongation:
654+
assert len(self.use_dim) >= 5, \
655+
f'When using elongation norm, expect used dimensions >= 5, got {len(self.use_dim)}' # noqa: E501
656+
points[:, 4] = np.tanh(points[:, 4])
649657
attribute_dims = None
650658

651659
if self.shift_height:
@@ -682,6 +690,8 @@ def __repr__(self) -> str:
682690
repr_str += f'backend_args={self.backend_args}, '
683691
repr_str += f'load_dim={self.load_dim}, '
684692
repr_str += f'use_dim={self.use_dim})'
693+
repr_str += f'norm_intensity={self.norm_intensity})'
694+
repr_str += f'norm_elongation={self.norm_elongation})'
685695
return repr_str
686696

687697

mmdet3d/models/necks/second_fpn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ def forward(self, x):
7474
"""Forward function.
7575
7676
Args:
77-
x (torch.Tensor): 4D Tensor in (N, C, H, W) shape.
77+
x (List[torch.Tensor]): Multi-level features with 4D Tensor in
78+
(N, C, H, W) shape.
7879
7980
Returns:
8081
list[torch.Tensor]: Multi-level feature maps.

projects/DSVT/README.md

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets
2+
3+
> [DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets](https://arxiv.org/abs/2301.06051)
4+
5+
<!-- [ALGORITHM] -->
6+
7+
## Abstract
8+
9+
Designing an efficient yet deployment-friendly 3D backbone to handle sparse point clouds is a fundamental problem
10+
in 3D perception. Compared with the customized sparse
11+
convolution, the attention mechanism in Transformers is
12+
more appropriate for flexibly modeling long-range relationships and is easier to be deployed in real-world applications.
13+
However, due to the sparse characteristics of point clouds,
14+
it is non-trivial to apply a standard transformer on sparse
15+
points. In this paper, we present Dynamic Sparse Voxel
16+
Transformer (DSVT), a single-stride window-based voxel
17+
Transformer backbone for outdoor 3D perception. In order
18+
to efficiently process sparse points in parallel, we propose
19+
Dynamic Sparse Window Attention, which partitions a series
20+
of local regions in each window according to its sparsity
21+
and then computes the features of all regions in a fully parallel manner. To allow the cross-set connection, we design
22+
a rotated set partitioning strategy that alternates between
23+
two partitioning configurations in consecutive self-attention
24+
layers. To support effective downsampling and better encode geometric information, we also propose an attention-style 3D pooling module on sparse points, which is powerful
25+
and deployment-friendly without utilizing any customized
26+
CUDA operations. Our model achieves state-of-the-art performance with a broad range of 3D perception tasks. More
27+
importantly, DSVT can be easily deployed by TensorRT with
28+
real-time inference speed (27Hz). Code will be available at
29+
https://github.com/Haiyang-W/DSVT.
30+
31+
<div align=center>
32+
<img src="https://github-production-user-asset-6210df.s3.amazonaws.com/34888372/245692705-e61be20c-2a7d-4ab9-85e3-b36f662c1bdf.png" width="800"/>
33+
</div>
34+
35+
## Introduction
36+
37+
We implement DSVT and provide the results on the Waymo dataset.
38+
39+
## Usage
40+
41+
<!-- For a typical model, this section should contain the commands for training and testing. You are also suggested to dump your environment specification to env.yml by `conda env export > env.yml`. -->
42+
43+
### Installation
44+
45+
```shell
46+
pip install torch_scatter==2.0.9
47+
python projects/DSVT/setup.py develop # compile `ingroup_inds_op` cuda operation
48+
```
49+
50+
### Testing commands
51+
52+
In MMDetection3D's root directory, run the following command to test the model:
53+
54+
```bash
55+
python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py ${CHECKPOINT_PATH}
56+
```
57+
58+
### Training commands
59+
60+
The support of training DSVT is on the way.
61+
62+
## Results and models
63+
64+
### Waymo
65+
66+
| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
67+
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :------: |
68+
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | | | 75.2 | 72.2 | 68.9 | 66.1 | |
69+
70+
**Note** that `ResSECOND` denotes that the base block in SECOND has residual layers.
71+
72+
## Citation
73+
74+
```latex
75+
@inproceedings{wang2023dsvt,
76+
title={DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets},
77+
author={Haiyang Wang and Chen Shi and Shaoshuai Shi and Meng Lei and Sen Wang and Di He and Bernt Schiele and Liwei Wang},
78+
booktitle={CVPR},
79+
year={2023}
80+
}
81+
```
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
_base_ = ['../../../configs/_base_/default_runtime.py']
2+
custom_imports = dict(
3+
imports=['projects.DSVT.dsvt'], allow_failed_imports=False)
4+
5+
voxel_size = [0.32, 0.32, 6]
6+
grid_size = [468, 468, 1]
7+
point_cloud_range = [-74.88, -74.88, -2, 74.88, 74.88, 4.0]
8+
data_root = 'data/waymo/kitti_format/'
9+
class_names = ['Car', 'Pedestrian', 'Cyclist']
10+
metainfo = dict(classes=class_names)
11+
input_modality = dict(use_lidar=True, use_camera=False)
12+
backend_args = None
13+
14+
model = dict(
15+
type='DSVT',
16+
data_preprocessor=dict(type='Det3DDataPreprocessor', voxel=False),
17+
voxel_encoder=dict(
18+
type='DynamicPillarVFE3D',
19+
with_distance=False,
20+
use_absolute_xyz=True,
21+
use_norm=True,
22+
num_filters=[192, 192],
23+
num_point_features=5,
24+
voxel_size=voxel_size,
25+
grid_size=grid_size,
26+
point_cloud_range=point_cloud_range),
27+
middle_encoder=dict(
28+
type='DSVTMiddleEncoder',
29+
input_layer=dict(
30+
sparse_shape=grid_size,
31+
downsample_stride=[],
32+
dim_model=[192],
33+
set_info=[[36, 4]],
34+
window_shape=[[12, 12, 1]],
35+
hybrid_factor=[2, 2, 1], # x, y, z
36+
shift_list=[[[0, 0, 0], [6, 6, 0]]],
37+
normalize_pos=False),
38+
set_info=[[36, 4]],
39+
dim_model=[192],
40+
dim_feedforward=[384],
41+
stage_num=1,
42+
nhead=[8],
43+
conv_out_channel=192,
44+
output_shape=[468, 468],
45+
dropout=0.,
46+
activation='gelu'),
47+
map2bev=dict(
48+
type='PointPillarsScatter3D',
49+
output_shape=grid_size,
50+
num_bev_feats=192),
51+
backbone=dict(
52+
type='ResSECOND',
53+
in_channels=192,
54+
out_channels=[128, 128, 256],
55+
blocks_nums=[1, 2, 2],
56+
layer_strides=[1, 2, 2]),
57+
neck=dict(
58+
type='SECONDFPN',
59+
in_channels=[128, 128, 256],
60+
out_channels=[128, 128, 128],
61+
upsample_strides=[1, 2, 4],
62+
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
63+
upsample_cfg=dict(type='deconv', bias=False),
64+
use_conv_for_no_stride=False),
65+
bbox_head=dict(
66+
type='DSVTCenterHead',
67+
in_channels=sum([128, 128, 128]),
68+
tasks=[dict(num_class=3, class_names=class_names)],
69+
common_heads=dict(
70+
reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), iou=(1, 2)),
71+
share_conv_channel=64,
72+
conv_cfg=dict(type='Conv2d'),
73+
norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.01),
74+
bbox_coder=dict(
75+
type='DSVTBBoxCoder',
76+
pc_range=point_cloud_range,
77+
max_num=500,
78+
post_center_range=[-80, -80, -10.0, 80, 80, 10.0],
79+
score_threshold=0.1,
80+
out_size_factor=1,
81+
voxel_size=voxel_size[:2],
82+
code_size=7),
83+
separate_head=dict(
84+
type='SeparateHead',
85+
init_bias=-2.19,
86+
final_kernel=3,
87+
norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.01)),
88+
loss_cls=dict(
89+
type='mmdet.GaussianFocalLoss', reduction='mean', loss_weight=1.0),
90+
loss_bbox=dict(type='mmdet.L1Loss', reduction='mean', loss_weight=2.0),
91+
norm_bbox=True),
92+
# model training and testing settings
93+
train_cfg=dict(
94+
pts=dict(
95+
grid_size=grid_size,
96+
voxel_size=voxel_size,
97+
out_size_factor=4,
98+
dense_reg=1,
99+
gaussian_overlap=0.1,
100+
max_objs=500,
101+
min_radius=2,
102+
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])),
103+
test_cfg=dict(
104+
max_per_img=500,
105+
max_pool_nms=False,
106+
min_radius=[4, 12, 10, 1, 0.85, 0.175],
107+
iou_rectifier=[[0.68, 0.71, 0.65]],
108+
pc_range=[-80, -80],
109+
out_size_factor=4,
110+
voxel_size=voxel_size[:2],
111+
nms_type='rotate',
112+
multi_class_nms=True,
113+
pre_max_size=[[4096, 4096, 4096]],
114+
post_max_size=[[500, 500, 500]],
115+
nms_thr=[[0.7, 0.6, 0.55]]))
116+
117+
db_sampler = dict(
118+
data_root=data_root,
119+
info_path=data_root + 'waymo_dbinfos_train.pkl',
120+
rate=1.0,
121+
prepare=dict(
122+
filter_by_difficulty=[-1],
123+
filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
124+
classes=class_names,
125+
sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10),
126+
points_loader=dict(
127+
type='LoadPointsFromFile',
128+
coord_type='LIDAR',
129+
load_dim=6,
130+
use_dim=[0, 1, 2, 3, 4],
131+
backend_args=backend_args),
132+
backend_args=backend_args)
133+
134+
train_pipeline = [
135+
dict(
136+
type='LoadPointsFromFile',
137+
coord_type='LIDAR',
138+
load_dim=6,
139+
use_dim=5,
140+
norm_intensity=True,
141+
backend_args=backend_args),
142+
# Add this if using `MultiFrameDeformableDecoderRPN`
143+
# dict(
144+
# type='LoadPointsFromMultiSweeps',
145+
# sweeps_num=9,
146+
# load_dim=6,
147+
# use_dim=[0, 1, 2, 3, 4],
148+
# pad_empty_sweeps=True,
149+
# remove_close=True),
150+
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
151+
dict(type='ObjectSample', db_sampler=db_sampler),
152+
dict(
153+
type='GlobalRotScaleTrans',
154+
rot_range=[-0.78539816, 0.78539816],
155+
scale_ratio_range=[0.95, 1.05],
156+
translation_std=[0.5, 0.5, 0]),
157+
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
158+
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
159+
dict(type='ObjectNameFilter', classes=class_names),
160+
dict(type='PointShuffle'),
161+
dict(
162+
type='Pack3DDetInputs',
163+
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
164+
]
165+
166+
test_pipeline = [
167+
dict(
168+
type='LoadPointsFromFile',
169+
coord_type='LIDAR',
170+
load_dim=6,
171+
use_dim=5,
172+
norm_intensity=True,
173+
norm_elongation=True,
174+
backend_args=backend_args),
175+
dict(
176+
type='MultiScaleFlipAug3D',
177+
img_scale=(1333, 800),
178+
pts_scale_ratio=1,
179+
flip=False,
180+
transforms=[
181+
dict(
182+
type='GlobalRotScaleTrans',
183+
rot_range=[0, 0],
184+
scale_ratio_range=[1., 1.],
185+
translation_std=[0, 0, 0]),
186+
dict(type='RandomFlip3D'),
187+
dict(
188+
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
189+
]),
190+
dict(type='Pack3DDetInputs', keys=['points'])
191+
]
192+
193+
dataset_type = 'WaymoDataset'
194+
val_dataloader = dict(
195+
batch_size=4,
196+
num_workers=4,
197+
persistent_workers=True,
198+
drop_last=False,
199+
sampler=dict(type='DefaultSampler', shuffle=False),
200+
dataset=dict(
201+
type=dataset_type,
202+
data_root=data_root,
203+
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
204+
ann_file='waymo_infos_val.pkl',
205+
pipeline=test_pipeline,
206+
modality=input_modality,
207+
test_mode=True,
208+
metainfo=metainfo,
209+
box_type_3d='LiDAR',
210+
backend_args=backend_args))
211+
test_dataloader = val_dataloader
212+
213+
val_evaluator = dict(
214+
type='WaymoMetric',
215+
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
216+
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
217+
data_root='./data/waymo/waymo_format',
218+
backend_args=backend_args,
219+
convert_kitti_format=False,
220+
idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl')
221+
test_evaluator = val_evaluator
222+
223+
vis_backends = [dict(type='LocalVisBackend')]
224+
visualizer = dict(
225+
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
226+
227+
# runtime settings
228+
val_cfg = dict()
229+
test_cfg = dict()
230+
231+
# Default setting for scaling LR automatically
232+
# - `enable` means enable scaling LR automatically
233+
# or not by default.
234+
# - `base_batch_size` = (8 GPUs) x (1 sample per GPU).
235+
# auto_scale_lr = dict(enable=False, base_batch_size=8)
236+
237+
default_hooks = dict(
238+
logger=dict(type='LoggerHook', interval=50),
239+
checkpoint=dict(type='CheckpointHook', interval=5))

projects/DSVT/dsvt/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from .dsvt import DSVT
2+
from .dsvt_head import DSVTCenterHead
3+
from .dsvt_transformer import DSVTMiddleEncoder
4+
from .dynamic_pillar_vfe import DynamicPillarVFE3D
5+
from .map2bev import PointPillarsScatter3D
6+
from .res_second import ResSECOND
7+
from .utils import DSVTBBoxCoder
8+
9+
__all__ = [
10+
'DSVTCenterHead', 'DSVT', 'DSVTMiddleEncoder', 'DynamicPillarVFE3D',
11+
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder'
12+
]

0 commit comments

Comments
 (0)