
Commit 02ac3e1

Support multi-modal 3D detection on NuScenes open-mmlab#1339
Add support for multi-modal NuScenes Detection
2 parents ad9c25c + fcfa077 commit 02ac3e1


41 files changed (+3863 / -33 lines)

README.md

Lines changed: 14 additions & 5 deletions
@@ -10,6 +10,7 @@
 * `OpenPCDet` has been updated to `v0.6.0` (Sep. 2022).
 * The code of PV-RCNN++ has been supported.
 * The code of MPPNet has been supported.
+* The multi-modal 3D detection approaches on NuScenes have been supported.
 
 ## Overview
 - [Changelog](#changelog)
@@ -22,10 +23,15 @@
 
 
 ## Changelog
-[2023-04-02] Added support for [`VoxelNeXt`](https://github.com/dvlab-research/VoxelNeXt) on NuScenes, Waymo, and Argoverse2 datasets. It is a fully sparse 3D object detection network: a clean sparse-CNN network that predicts 3D objects directly upon voxels.
+[2023-05-13] **NEW:** Added support for multi-modal 3D object detection models on the NuScenes dataset.
+* Support multi-modal NuScenes detection (see [GETTING_STARTED.md](docs/GETTING_STARTED.md) for how to process the data).
+* Support the [TransFusion-Lidar](https://arxiv.org/abs/2203.11496) head, which achieves 69.43% NDS on the NuScenes validation dataset.
+* Support [`BEVFusion`](https://arxiv.org/abs/2205.13542), which fuses multi-modal information in BEV space and reaches 70.98% NDS on the NuScenes validation dataset (see the [guideline](docs/guidelines_of_approaches/bevfusion.md) on how to train/test with BEVFusion).
+
+[2023-04-02] Added support for [`VoxelNeXt`](https://arxiv.org/abs/2303.11301) on NuScenes, Waymo, and Argoverse2 datasets. It is a fully sparse 3D object detection network: a clean sparse-CNN network that predicts 3D objects directly upon voxels.
 
 [2022-09-02] **NEW:** Update `OpenPCDet` to v0.6.0:
-* Official code release of [MPPNet](https://arxiv.org/abs/2205.05979) for temporal 3D object detection, which supports long-term multi-frame 3D object detection and ranked 1st on the [3D detection leaderboard](https://waymo.com/open/challenges/2020/3d-detection) of the Waymo Open Dataset on Sept. 2nd, 2022. On the validation set, MPPNet achieves 74.96%, 75.06% and 74.52% mAPH@Level_2 for the vehicle, pedestrian and cyclist classes (see the [guideline](docs/guidelines_of_approaches/mppnet.md) on how to train/test with MPPNet).
+* Official code release of [`MPPNet`](https://arxiv.org/abs/2205.05979) for temporal 3D object detection, which supports long-term multi-frame 3D object detection and ranked 1st on the [3D detection leaderboard](https://waymo.com/open/challenges/2020/3d-detection) of the Waymo Open Dataset on Sept. 2nd, 2022. On the validation set, MPPNet achieves 74.96%, 75.06% and 74.52% mAPH@Level_2 for the vehicle, pedestrian and cyclist classes (see the [guideline](docs/guidelines_of_approaches/mppnet.md) on how to train/test with MPPNet).
 * Support multi-frame training/testing on the Waymo Open Dataset (see the [changelog](docs/changelog.md) for more details on how to process the data).
 * Support saving training details (e.g., loss, iter, epoch) to a file (the previous tqdm progress bar is still available via `--use_tqdm_to_record`). Please run `pip install gpustat` if you also want to log GPU-related information.
 * Support saving the latest model every 5 minutes, so training can be restored from the latest status instead of the previous epoch.
@@ -38,10 +44,10 @@
 
 [2022-02-07] Added support for CenterPoint models on the NuScenes dataset.
 
-[2022-01-14] Added support for dynamic pillar voxelization, following the implementation proposed in [H^23D R-CNN](https://arxiv.org/abs/2107.14391), with unique operation and the [`torch_scatter`](https://github.com/rusty1s/pytorch_scatter) package.
+[2022-01-14] Added support for dynamic pillar voxelization, following the implementation proposed in [`H^23D R-CNN`](https://arxiv.org/abs/2107.14391), with unique operation and the [`torch_scatter`](https://github.com/rusty1s/pytorch_scatter) package.
 
 [2022-01-05] **NEW:** Update `OpenPCDet` to v0.5.2:
-* The code of [PV-RCNN++](https://arxiv.org/abs/2102.00463) has been released to this repo, with higher performance, faster training/inference speed and less memory consumption than PV-RCNN.
+* The code of [`PV-RCNN++`](https://arxiv.org/abs/2102.00463) has been released to this repo, with higher performance, faster training/inference speed and less memory consumption than PV-RCNN.
 * Add performance of several models trained with the full training set of the [Waymo Open Dataset](#waymo-open-dataset-baselines).
 * Support the Lyft dataset, see the pull request [here](https://github.com/open-mmlab/OpenPCDet/pull/720).
 
@@ -199,7 +205,7 @@
 but you could easily achieve similar performance by training with the default configs.
 
 ### NuScenes 3D Object Detection Baselines
-All models are trained with 8 GTX 1080Ti GPUs and are available for download.
+All models are trained with 8 GPUs and are available for download. For training BEVFusion, please refer to the [guideline](docs/guidelines_of_approaches/bevfusion.md).
 
 | | mATE | mASE | mAOE | mAVE | mAAE | mAP | NDS | download |
 |----------------------------------------------------------------------------------------------------|-------:|:------:|:------:|:-----:|:-----:|:-----:|:------:|:--------------------------------------------------------------------------------------------------:|
@@ -209,7 +215,10 @@ All models are trained with 8 GPUs and are available for download.
 | [CenterPoint (voxel_size=0.1)](tools/cfgs/nuscenes_models/cbgs_voxel01_res3d_centerpoint.yaml) | 30.11 | 25.55 | 38.28 | 21.94 | 18.87 | 56.03 | 64.54 | [model-34M](https://drive.google.com/file/d/1Cz-J1c3dw7JAWc25KRG1XQj8yCaOlexQ/view?usp=sharing) |
 | [CenterPoint (voxel_size=0.075)](tools/cfgs/nuscenes_models/cbgs_voxel0075_res3d_centerpoint.yaml) | 28.80 | 25.43 | 37.27 | 21.55 | 18.24 | 59.22 | 66.48 | [model-34M](https://drive.google.com/file/d/1XOHAWm1MPkCKr1gqmc3TWi5AYZgPsgxU/view?usp=sharing) |
 | [VoxelNeXt (voxel_size=0.075)](tools/cfgs/nuscenes_models/cbgs_voxel0075_voxelnext.yaml) | 30.11 | 25.23 | 40.57 | 21.69 | 18.56 | 60.53 | 66.65 | [model-31M](https://drive.google.com/file/d/1IV7e7G9X-61KXSjMGtQo579pzDNbhwvf/view?usp=share_link) |
+| [TransFusion-L*](tools/cfgs/nuscenes_models/transfusion_lidar.yaml) | 27.96 | 25.37 | 29.35 | 27.31 | 18.55 | 64.58 | 69.43 | [model-32M](https://drive.google.com/file/d/1cuZ2qdDnxSwTCsiXWwbqCGF-uoazTXbz/view?usp=share_link) |
+| [BEVFusion](tools/cfgs/nuscenes_models/bevfusion.yaml) | 28.03 | 25.43 | 30.19 | 26.76 | 18.48 | 67.75 | 70.98 | [model-157M](https://drive.google.com/file/d/1X50b-8immqlqD8VPAUkSKI0Ls-4k37g9/view?usp=share_link) |
 
+*: Use the fade strategy, which disables data augmentations in the last several epochs during training.
 
 ### ONCE 3D Object Detection Baselines
 All models are trained with 8 GPUs.

docs/GETTING_STARTED.md

Lines changed: 7 additions & 0 deletions
@@ -53,9 +53,16 @@ pip install nuscenes-devkit==1.0.5
 
 * Generate the data infos by running the following command (it may take several hours):
 ```shell
+# for lidar-only setting
 python -m pcdet.datasets.nuscenes.nuscenes_dataset --func create_nuscenes_infos \
     --cfg_file tools/cfgs/dataset_configs/nuscenes_dataset.yaml \
     --version v1.0-trainval
+
+# for multi-modal setting
+python -m pcdet.datasets.nuscenes.nuscenes_dataset --func create_nuscenes_infos \
+    --cfg_file tools/cfgs/dataset_configs/nuscenes_dataset.yaml \
+    --version v1.0-trainval \
+    --with_cam
 ```
 
 ### Waymo Open Dataset
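After generating the multi-modal infos, a quick way to confirm that camera information was actually written is to open one of the generated info files. This is only a hedged sketch: the filename follows the lidar-only naming convention, and the per-frame camera field (`cams`) is an assumption based on this PR's data pipeline rather than something stated in the diff above.

```python
import pickle
from pathlib import Path

# Assumed filename convention and 'cams' field; adjust to your data root.
info_path = Path('data/nuscenes/v1.0-trainval/nuscenes_infos_10sweeps_train.pkl')
with open(info_path, 'rb') as f:
    infos = pickle.load(f)

first = infos[0]
print(first.keys())                    # multi-modal infos should carry camera entries
if 'cams' in first:                    # assumption: per-camera dict keyed by channel
    print(list(first['cams'].keys()))  # e.g. CAM_FRONT, CAM_BACK, ...
```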
docs/guidelines_of_approaches/bevfusion.md

Lines changed: 35 additions & 0 deletions (new file)

@@ -0,0 +1,35 @@
+
+## Installation
+
+Please refer to [INSTALL.md](../INSTALL.md) for the installation of `OpenPCDet`.
+* We recommend pinning `pillow==8.4.0` to avoid a bug in BEV pooling.
+
+## Data Preparation
+Please refer to [GETTING_STARTED.md](../GETTING_STARTED.md) to process the multi-modal NuScenes dataset.
+
+## Training
+
+1. Train the lidar branch of BEVFusion:
+```shell
+bash scripts/dist_train.sh ${NUM_GPUS} --cfg_file cfgs/nuscenes_models/transfusion_lidar.yaml
+```
+The checkpoint will be saved in `../output/nuscenes_models/transfusion_lidar/default/ckpt`, or you can download a pretrained checkpoint directly from [here](https://drive.google.com/file/d/1cuZ2qdDnxSwTCsiXWwbqCGF-uoazTXbz/view?usp=share_link).
+
+2. To train BEVFusion, download the pretrained parameters for the image backbone [here](https://drive.google.com/file/d/1v74WCt4_5ubjO7PciA5T0xhQc9bz_jZu/view?usp=share_link) and specify the path in the [config](../../tools/cfgs/nuscenes_models/bevfusion.yaml#L88). Then run the following command:
+```shell
+bash scripts/dist_train.sh ${NUM_GPUS} --cfg_file cfgs/nuscenes_models/bevfusion.yaml \
+    --pretrained_model path_to_pretrained_lidar_branch_ckpt
+```
+## Evaluation
+* Test with a pretrained model:
+```shell
+bash scripts/dist_test.sh ${NUM_GPUS} --cfg_file cfgs/nuscenes_models/bevfusion.yaml \
+    --ckpt ../output/cfgs/nuscenes_models/bevfusion/default/ckpt/checkpoint_epoch_6.pth
+```
+
+## Performance
+All models are trained with spconv 1.0, but you can directly load them for testing regardless of the spconv version.
+| | mATE | mASE | mAOE | mAVE | mAAE | mAP | NDS | download |
+|----------------------------------------------------------------------------------------------------|-------:|:------:|:------:|:-----:|:-----:|:-----:|:------:|:--------------------------------------------------------------------------------------------------:|
+| [TransFusion-L](../../tools/cfgs/nuscenes_models/transfusion_lidar.yaml) | 27.96 | 25.37 | 29.35 | 27.31 | 18.55 | 64.58 | 69.43 | [model-32M](https://drive.google.com/file/d/1cuZ2qdDnxSwTCsiXWwbqCGF-uoazTXbz/view?usp=share_link) |
+| [BEVFusion](../../tools/cfgs/nuscenes_models/bevfusion.yaml) | 28.03 | 25.43 | 30.19 | 26.76 | 18.48 | 67.75 | 70.98 | [model-157M](https://drive.google.com/file/d/1X50b-8immqlqD8VPAUkSKI0Ls-4k37g9/view?usp=share_link) |
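Since the guideline above pins `pillow==8.4.0`, a small runtime guard can catch mismatched environments before a long training run. This is a sketch, not part of the commit:

```python
import PIL

# The BEVFusion guideline recommends pillow==8.4.0 to avoid a bug in BEV
# pooling; warn early if the installed version deviates from that pin.
if PIL.__version__ != '8.4.0':
    print(f'Warning: pillow {PIL.__version__} detected; '
          'the guideline recommends 8.4.0 to avoid the BEV-pooling bug.')
```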

pcdet/datasets/augmentor/data_augmentor.py

Lines changed: 36 additions & 0 deletions
@@ -1,6 +1,7 @@
 from functools import partial
 
 import numpy as np
+from PIL import Image
 
 from ...utils import common_utils
 from . import augmentor_utils, database_sampler
@@ -23,6 +24,18 @@ def __init__(self, root_path, augmentor_configs, class_names, logger=None):
             cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
             self.data_augmentor_queue.append(cur_augmentor)
 
+    def disable_augmentation(self, augmentor_configs):
+        self.data_augmentor_queue = []
+        aug_config_list = augmentor_configs if isinstance(augmentor_configs, list) \
+            else augmentor_configs.AUG_CONFIG_LIST
+
+        for cur_cfg in aug_config_list:
+            if not isinstance(augmentor_configs, list):
+                if cur_cfg.NAME in augmentor_configs.DISABLE_AUG_LIST:
+                    continue
+            cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
+            self.data_augmentor_queue.append(cur_augmentor)
+
     def gt_sampling(self, config=None):
         db_sampler = database_sampler.DataBaseSampler(
             root_path=self.root_path,
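This new hook is what the README's fade-strategy footnote relies on: near the end of training, the augmentor queue is rebuilt without the augmentations named in `DISABLE_AUG_LIST`. A minimal sketch of how a training loop might call it — the loop scaffolding and the `NUM_LAST_EPOCHS` knob are assumptions for illustration, not part of this commit:

```python
# Sketch of the fade strategy built on disable_augmentation(). Assumes
# `train_set.data_augmentor` is the DataAugmentor shown above and that the
# dataset config carries DISABLE_AUG_LIST (e.g. ['gt_sampling']).
NUM_LAST_EPOCHS = 5  # hypothetical: how many final epochs run without those augs

def train(model, train_set, train_loader, total_epochs, data_aug_cfg):
    for epoch in range(total_epochs):
        if epoch == total_epochs - NUM_LAST_EPOCHS:
            # Rebuild the queue, skipping entries in data_aug_cfg.DISABLE_AUG_LIST
            train_set.data_augmentor.disable_augmentation(data_aug_cfg)
        for batch in train_loader:
            ...  # forward/backward as usual
```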
@@ -139,6 +152,7 @@ def random_world_translation(self, data_dict=None, config=None):
 
         data_dict['gt_boxes'] = gt_boxes
         data_dict['points'] = points
+        data_dict['noise_translate'] = noise_translate
         return data_dict
 
     def random_local_translation(self, data_dict=None, config=None):
@@ -251,6 +265,28 @@ def random_local_pyramid_aug(self, data_dict=None, config=None):
         data_dict['points'] = points
         return data_dict
 
+    def imgaug(self, data_dict=None, config=None):
+        if data_dict is None:
+            return partial(self.imgaug, config=config)
+        imgs = data_dict["camera_imgs"]
+        img_process_infos = data_dict['img_process_infos']
+        new_imgs = []
+        for img, img_process_info in zip(imgs, img_process_infos):
+            flip = False
+            if config.RAND_FLIP and np.random.choice([0, 1]):
+                flip = True
+            rotate = np.random.uniform(*config.ROT_LIM)
+            # aug images
+            if flip:
+                img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
+            img = img.rotate(rotate)
+            img_process_info[2] = flip
+            img_process_info[3] = rotate
+            new_imgs.append(img)
+
+        data_dict["camera_imgs"] = new_imgs
+        return data_dict
+
     def forward(self, data_dict):
         """
         Args:
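In isolation, the per-image transform in `imgaug` above amounts to an optional horizontal flip plus a small random rotation, with both parameters recorded (`img_process_info[2:4]`) so that later stages can account for them when aligning cameras with the lidar frame. A standalone sketch — the `ROT_LIM` value and image size are assumed placeholders, not values taken from the configs:

```python
import numpy as np
from PIL import Image

ROT_LIM = (-5.4, 5.4)  # degrees; the actual config value may differ

img = Image.new('RGB', (1600, 900))       # stand-in for a NuScenes camera image
flip = bool(np.random.choice([0, 1]))     # RAND_FLIP behaviour: flip half the time
rotate = float(np.random.uniform(*ROT_LIM))
if flip:
    img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
img = img.rotate(rotate)                  # PIL rotates counter-clockwise
# the real augmentor stores: img_process_info[2] = flip, img_process_info[3] = rotate
```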

pcdet/datasets/dataset.py

Lines changed: 28 additions & 0 deletions
@@ -2,6 +2,7 @@
 from pathlib import Path
 
 import numpy as np
+import torch
 import torch.utils.data as torch_data
 
 from ..utils import common_utils
@@ -130,6 +131,30 @@ def __getitem__(self, index):
         """
         raise NotImplementedError
 
+    def set_lidar_aug_matrix(self, data_dict):
+        """
+        Get lidar augment matrix (4 x 4), which is used to recover orig point coordinates.
+        """
+        lidar_aug_matrix = np.eye(4)
+        if 'flip_y' in data_dict.keys():
+            flip_x = data_dict['flip_x']
+            flip_y = data_dict['flip_y']
+            if flip_x:
+                lidar_aug_matrix[:3, :3] = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) @ lidar_aug_matrix[:3, :3]
+            if flip_y:
+                lidar_aug_matrix[:3, :3] = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) @ lidar_aug_matrix[:3, :3]
+        if 'noise_rot' in data_dict.keys():
+            noise_rot = data_dict['noise_rot']
+            lidar_aug_matrix[:3, :3] = common_utils.angle2matrix(torch.tensor(noise_rot)) @ lidar_aug_matrix[:3, :3]
+        if 'noise_scale' in data_dict.keys():
+            noise_scale = data_dict['noise_scale']
+            lidar_aug_matrix[:3, :3] *= noise_scale
+        if 'noise_translate' in data_dict.keys():
+            noise_translate = data_dict['noise_translate']
+            lidar_aug_matrix[:3, 3:4] = noise_translate.T
+        data_dict['lidar_aug_matrix'] = lidar_aug_matrix
+        return data_dict
+
     def prepare_data(self, data_dict):
         """
         Args:
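Because `lidar_aug_matrix` accumulates flip, rotation, scaling and translation in order, augmented points can be mapped back to the original frame by applying its inverse in homogeneous coordinates — this is what the docstring's "recover orig point coordinates" refers to. A small self-contained sketch of that inverse mapping (not code from the commit):

```python
import numpy as np

def recover_original_points(points_xyz, lidar_aug_matrix):
    """Undo the accumulated lidar augmentations via the inverse 4x4 matrix."""
    homo = np.hstack([points_xyz, np.ones((points_xyz.shape[0], 1))])
    return (np.linalg.inv(lidar_aug_matrix) @ homo.T).T[:, :3]

# Example: a pure scaling augmentation with noise_scale = 1.05
M = np.eye(4)
M[:3, :3] *= 1.05
aug_points = np.array([[1.05, 2.10, 3.15]])
print(recover_original_points(aug_points, M))  # ~[[1.0, 2.0, 3.0]]
```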
@@ -165,6 +190,7 @@ def prepare_data(self, data_dict):
             )
         if 'calib' in data_dict:
             data_dict['calib'] = calib
+        data_dict = self.set_lidar_aug_matrix(data_dict)
         if data_dict.get('gt_boxes', None) is not None:
             selected = common_utils.keep_arrays_by_name(data_dict['gt_names'], self.class_names)
             data_dict['gt_boxes'] = data_dict['gt_boxes'][selected]
@@ -287,6 +313,8 @@ def collate_batch(batch_list, _unused=False):
                             constant_values=pad_value)
                         points.append(points_pad)
                     ret[key] = np.stack(points, axis=0)
+                elif key in ['camera_imgs']:
+                    ret[key] = torch.stack([torch.stack(imgs, dim=0) for imgs in val], dim=0)
                 else:
                     ret[key] = np.stack(val, axis=0)
             except:
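The new `camera_imgs` branch stacks each sample's list of per-camera image tensors into a single batch tensor with layout (batch, num_cams, C, H, W). A small sketch with illustrative sizes (the image resolution and camera count here are placeholders, not values from the configs):

```python
import torch

# 2 samples, 6 cameras each, 3x256x704 images (illustrative sizes)
val = [[torch.zeros(3, 256, 704) for _ in range(6)] for _ in range(2)]
batched = torch.stack([torch.stack(imgs, dim=0) for imgs in val], dim=0)
print(batched.shape)  # torch.Size([2, 6, 3, 256, 704])
```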
