Skip to content

Commit ad9c25c

Browse files
authored
Some detailed modifications for Argo2 and VoxelNeXt (open-mmlab#1327)
* Add files via upload * Delete cbgs_voxel01_voxelnext_headkernel3.yaml * Delete voxelnext_ioubranch.yaml
1 parent 81763e7 commit ad9c25c

File tree

6 files changed

+245
-150
lines changed

6 files changed

+245
-150
lines changed

README.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -176,18 +176,19 @@ By default, all models are trained with **a single frame** of **20% data (~32k f
176176

177177
Here we also provide the performance of several models trained on the full training set (refer to the paper of [PV-RCNN++](https://arxiv.org/abs/2102.00463)):
178178

179-
| Performance@(train with 100\% Data) | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 |
180-
|---------------------------------------------|----------:|:-------:|:-------:|:-------:|:-------:|:-------:|
181-
| [SECOND](tools/cfgs/waymo_models/second.yaml) | 72.27/71.69 | 63.85/63.33 | 68.70/58.18 | 60.72/51.31 | 60.62/59.28 | 58.34/57.05 |
182-
| [CenterPoint-Pillar](tools/cfgs/waymo_models/centerpoint_pillar_1x.yaml)| 73.37/72.86 | 65.09/64.62 | 75.35/65.11 | 67.61/58.25 | 67.76/66.22 | 65.25/63.77 |
183-
| [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 77.05/76.51 | 68.47/67.97 | 75.24/66.87 | 66.18/58.62 | 68.60/67.36 | 66.13/64.93 |
184-
| [VoxelNeXt-2D](tools/cfgs/waymo_models/voxelnext2d_ioubranch.yaml) | 77.94/77.47 |69.68/69.25 |80.24/73.47 |72.23/65.88 |73.33/72.20 |70.66/69.56 |
185-
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 78.00/77.50 | 69.43/68.98 | 79.21/73.03 | 70.42/64.72 | 71.46/70.27 | 68.95/67.79 |
186-
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 79.10/78.63 | 70.34/69.91 | 80.62/74.62 | 71.86/66.30 | 73.49/72.38 | 70.70/69.62 |
187-
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) | 79.25/78.78 | 70.61/70.18 | 81.83/76.28 | 73.17/68.00 | 73.72/72.66 | 71.21/70.19 |
179+
| Performance@(train with 100\% Data) | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 |
180+
|-------------------------------------------------------------------------------------------|----------:|:-------:|:-------:|:-------:|:-------:|:-------:|
181+
| [SECOND](tools/cfgs/waymo_models/second.yaml) | 72.27/71.69 | 63.85/63.33 | 68.70/58.18 | 60.72/51.31 | 60.62/59.28 | 58.34/57.05 |
182+
| [CenterPoint-Pillar](tools/cfgs/waymo_models/centerpoint_pillar_1x.yaml) | 73.37/72.86 | 65.09/64.62 | 75.35/65.11 | 67.61/58.25 | 67.76/66.22 | 65.25/63.77 |
183+
| [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 77.05/76.51 | 68.47/67.97 | 75.24/66.87 | 66.18/58.62 | 68.60/67.36 | 66.13/64.93 |
184+
| [VoxelNeXt-2D](tools/cfgs/waymo_models/voxelnext2d_ioubranch.yaml) | 77.94/77.47 |69.68/69.25 |80.24/73.47 |72.23/65.88 |73.33/72.20 |70.66/69.56 |
185+
| [VoxelNeXt](tools/cfgs/waymo_models/voxelnext_ioubranch_large.yaml) | 78.16/77.70 |69.86/69.42 |81.47/76.30 |73.48/68.63 |76.06/74.90 |73.29/72.18 |
186+
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 78.00/77.50 | 69.43/68.98 | 79.21/73.03 | 70.42/64.72 | 71.46/70.27 | 68.95/67.79 |
187+
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 79.10/78.63 | 70.34/69.91 | 80.62/74.62 | 71.86/66.30 | 73.49/72.38 | 70.70/69.62 |
188+
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) | 79.25/78.78 | 70.61/70.18 | 81.83/76.28 | 73.17/68.00 | 73.72/72.66 | 71.21/70.19 |
188189
| [PV-RCNN++ (ResNet, 2 frames)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet_2frames.yaml) | 80.17/79.70 | 72.14/71.70 | 83.48/80.42 | 75.54/72.61 | 74.63/73.75 | 72.35/71.50 |
189-
| [MPPNet (4 frames)](docs/guidelines_of_approaches/mppnet.md) | 81.54/81.06 | 74.07/73.61 | 84.56/81.94 | 77.20/74.67 | 77.15/76.50 | 75.01/74.38 |
190-
| [MPPNet (16 frames)](docs/guidelines_of_approaches/mppnet.md) | 82.74/82.28 | 75.41/74.96 | 84.69/82.25 | 77.43/75.06 | 77.28/76.66 | 75.13/74.52 |
190+
| [MPPNet (4 frames)](docs/guidelines_of_approaches/mppnet.md) | 81.54/81.06 | 74.07/73.61 | 84.56/81.94 | 77.20/74.67 | 77.15/76.50 | 75.01/74.38 |
191+
| [MPPNet (16 frames)](docs/guidelines_of_approaches/mppnet.md) | 82.74/82.28 | 75.41/74.96 | 84.69/82.25 | 77.43/75.06 | 77.28/76.66 | 75.13/74.52 |
191192

192193

193194

@@ -226,8 +227,7 @@ All models are trained with 4 GPUs.
226227

227228
| | mAP | download |
228229
|---------------------------------------------------------|:----:|:--------------------------------------------------------------------------------------------------:|
229-
| [VoxelNeXt](tools/cfgs/argo2_models/cbgs_voxel01_voxelnext.yaml) | 30.0 | [model-30M](https://drive.google.com/file/d/1zr-it1ERJzLQ3a3hP060z_EQqS_RkNaC/view?usp=share_link) |
230-
| [VoxelNeXt-K3](tools/cfgs/argo2_models/cbgs_voxel01_voxelnext_headkernel3.yaml) | 30.7 | [model-45M](https://drive.google.com/file/d/1NrYRsiKbuWyL8jE4SY27IHpFMY9K0o__/view?usp=share_link) |
230+
| [VoxelNeXt](tools/cfgs/argo2_models/cbgs_voxel01_voxelnext.yaml) | 30.5 | [model-32M](https://drive.google.com/file/d/1YP2UOz-yO-cWfYQkIqILEu6bodvCBVrR/view?usp=share_link) |
231231

232232
### Other datasets
233233
Welcome to support other datasets by submitting pull request.

pcdet/datasets/argo2/argo2_dataset.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -488,39 +488,23 @@ def parse_config():
488488
parser = argparse.ArgumentParser(description='arg parser')
489489
parser.add_argument('--root_path', type=str, default="/data/argo2/sensor")
490490
parser.add_argument('--output_dir', type=str, default="/data/argo2/processed")
491-
parser.add_argument('--num_process', type=int, default=16)
492491
args = parser.parse_args()
493492
return args
494493

495-
def main(seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin, token, num_process):
496-
for seg_i, seg_path in enumerate(seg_path_list):
497-
if seg_i % num_process != token:
498-
continue
499-
print(f'processing segment: {seg_i}/{len(seg_path_list)}')
500-
split = seg_split_list[seg_i]
501-
process_single_segment(seg_path, split, info_list, ts2idx, output_dir, save_bin)
502494

503495
if __name__ == '__main__':
504496
args = parse_config()
505497
root = args.root_path
506498
output_dir = args.output_dir
507-
num_process = args.num_process
508499
save_bin = True
509500
ts2idx, seg_path_list, seg_split_list = prepare(root)
510501

511-
if num_process > 1:
512-
with mp.Manager() as manager:
513-
info_list = manager.list()
514-
pool = mp.Pool(num_process)
515-
for token in range(num_process):
516-
result = pool.apply_async(main, args=(
517-
seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin, token, num_process))
518-
pool.close()
519-
pool.join()
520-
info_list = list(info_list)
521-
else:
522-
info_list = []
523-
main(seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin, 0, 1)
502+
velodyne_dir = Path(output_dir) / 'training' / 'velodyne'
503+
if not velodyne_dir.exists():
504+
velodyne_dir.mkdir(parents=True, exist_ok=True)
505+
506+
info_list = []
507+
create_argo2_infos(seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin, 0, 1)
524508

525509
assert len(info_list) > 0
526510

@@ -551,4 +535,3 @@ def main(seg_path_list, seg_split_list, info_list, ts2idx, output_dir, save_bin,
551535

552536
gts = pd.concat(seg_anno_list).reset_index()
553537
gts.to_feather(save_feather_path)
554-

pcdet/models/backbones_3d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .spconv_backbone_2d import PillarBackBone8x, PillarRes18BackBone8x
44
from .spconv_backbone_focal import VoxelBackBone8xFocal
55
from .spconv_backbone_voxelnext import VoxelResBackBone8xVoxelNeXt
6+
from .spconv_backbone_voxelnext2d import VoxelResBackBone8xVoxelNeXt2D
67
from .spconv_unet import UNetV2
78

89
__all__ = {
@@ -13,6 +14,7 @@
1314
'VoxelResBackBone8x': VoxelResBackBone8x,
1415
'VoxelBackBone8xFocal': VoxelBackBone8xFocal,
1516
'VoxelResBackBone8xVoxelNeXt': VoxelResBackBone8xVoxelNeXt,
17+
'VoxelResBackBone8xVoxelNeXt2D': VoxelResBackBone8xVoxelNeXt2D,
1618
'PillarBackBone8x': PillarBackBone8x,
1719
'PillarRes18BackBone8x': PillarRes18BackBone8x
1820
}
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
from functools import partial
2+
import torch
3+
import torch.nn as nn
4+
5+
from ...utils.spconv_utils import replace_feature, spconv
6+
7+
8+
def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0,
9+
conv_type='subm', norm_fn=None):
10+
11+
if conv_type == 'subm':
12+
conv = spconv.SubMConv2d(in_channels, out_channels, kernel_size, bias=False, indice_key=indice_key)
13+
elif conv_type == 'spconv':
14+
conv = spconv.SparseConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
15+
bias=False, indice_key=indice_key)
16+
elif conv_type == 'inverseconv':
17+
conv = spconv.SparseInverseConv2d(in_channels, out_channels, kernel_size, indice_key=indice_key, bias=False)
18+
else:
19+
raise NotImplementedError
20+
21+
m = spconv.SparseSequential(
22+
conv,
23+
norm_fn(out_channels),
24+
nn.ReLU(),
25+
)
26+
27+
return m
28+
29+
30+
class SparseBasicBlock(spconv.SparseModule):
31+
expansion = 1
32+
33+
def __init__(self, inplanes, planes, stride=1, norm_fn=None, downsample=None, indice_key=None):
34+
super(SparseBasicBlock, self).__init__()
35+
36+
assert norm_fn is not None
37+
bias = norm_fn is not None
38+
self.conv1 = spconv.SubMConv2d(
39+
inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key
40+
)
41+
self.bn1 = norm_fn(planes)
42+
self.relu = nn.ReLU()
43+
self.conv2 = spconv.SubMConv2d(
44+
planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key
45+
)
46+
self.bn2 = norm_fn(planes)
47+
self.downsample = downsample
48+
self.stride = stride
49+
50+
def forward(self, x):
51+
identity = x
52+
53+
out = self.conv1(x)
54+
out = replace_feature(out, self.bn1(out.features))
55+
out = replace_feature(out, self.relu(out.features))
56+
57+
out = self.conv2(out)
58+
out = replace_feature(out, self.bn2(out.features))
59+
60+
if self.downsample is not None:
61+
identity = self.downsample(x)
62+
63+
out = replace_feature(out, out.features + identity.features)
64+
out = replace_feature(out, self.relu(out.features))
65+
66+
return out
67+
68+
69+
class VoxelResBackBone8xVoxelNeXt2D(nn.Module):
70+
def __init__(self, model_cfg, input_channels, grid_size, **kwargs):
71+
super().__init__()
72+
self.model_cfg = model_cfg
73+
norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
74+
self.sparse_shape = grid_size[[1, 0]]
75+
76+
block = post_act_block
77+
78+
spconv_kernel_sizes = model_cfg.get('SPCONV_KERNEL_SIZES', [3, 3, 3, 3])
79+
80+
self.conv1 = spconv.SparseSequential(
81+
SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res1'),
82+
SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res1'),
83+
SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res1'),
84+
)
85+
86+
self.conv2 = spconv.SparseSequential(
87+
# [1600, 1408] <- [800, 704]
88+
block(32, 64, spconv_kernel_sizes[0], norm_fn=norm_fn, stride=2, padding=int(spconv_kernel_sizes[0]//2), indice_key='spconv2', conv_type='spconv'),
89+
SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res2'),
90+
SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res2'),
91+
SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res2'),
92+
SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res2'),
93+
)
94+
95+
self.conv3 = spconv.SparseSequential(
96+
# [800, 704] <- [400, 352]
97+
block(64, 128, spconv_kernel_sizes[1], norm_fn=norm_fn, stride=2, padding=int(spconv_kernel_sizes[1]//2), indice_key='spconv3', conv_type='spconv'),
98+
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
99+
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
100+
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
101+
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
102+
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
103+
SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res3'),
104+
)
105+
106+
self.conv4 = spconv.SparseSequential(
107+
# [400, 352] <- [200, 176]
108+
block(128, 256, spconv_kernel_sizes[2], norm_fn=norm_fn, stride=2, padding=int(spconv_kernel_sizes[2]//2), indice_key='spconv4', conv_type='spconv'),
109+
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res4'),
110+
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res4'),
111+
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res4'),
112+
)
113+
114+
self.conv5 = spconv.SparseSequential(
115+
# [400, 352] <- [200, 176]
116+
block(256, 256, spconv_kernel_sizes[3], norm_fn=norm_fn, stride=2, padding=int(spconv_kernel_sizes[3]//2), indice_key='spconv5', conv_type='spconv'),
117+
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res5'),
118+
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res5'),
119+
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res5'),
120+
)
121+
122+
self.conv6 = spconv.SparseSequential(
123+
# [400, 352] <- [200, 176]
124+
block(256, 256, spconv_kernel_sizes[3], norm_fn=norm_fn, stride=2, padding=int(spconv_kernel_sizes[3]//2), indice_key='spconv6', conv_type='spconv'),
125+
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res6'),
126+
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res6'),
127+
SparseBasicBlock(256, 256, norm_fn=norm_fn, indice_key='res6'),
128+
)
129+
130+
self.conv_out = spconv.SparseSequential(
131+
# [200, 150, 5] -> [200, 150, 2]
132+
spconv.SparseConv2d(256, 256, 3, stride=1, padding=1, bias=False, indice_key='spconv_down2'),
133+
norm_fn(256),
134+
nn.ReLU(),
135+
)
136+
137+
self.shared_conv = spconv.SparseSequential(
138+
spconv.SubMConv2d(256, 256, 3, stride=1, padding=1, bias=True),
139+
nn.BatchNorm1d(256),
140+
nn.ReLU(True),
141+
)
142+
143+
self.num_point_features = 256
144+
self.backbone_channels = {
145+
'x_conv1': 32,
146+
'x_conv2': 64,
147+
'x_conv3': 128,
148+
'x_conv4': 256,
149+
'x_conv5': 256
150+
}
151+
self.forward_ret_dict = {}
152+
153+
def bev_out(self, x_conv):
154+
features_cat = x_conv.features
155+
indices_cat = x_conv.indices
156+
157+
indices_unique, _inv = torch.unique(indices_cat, dim=0, return_inverse=True)
158+
features_unique = features_cat.new_zeros((indices_unique.shape[0], features_cat.shape[1]))
159+
features_unique.index_add_(0, _inv, features_cat)
160+
161+
x_out = spconv.SparseConvTensor(
162+
features=features_unique,
163+
indices=indices_unique,
164+
spatial_shape=x_conv.spatial_shape,
165+
batch_size=x_conv.batch_size
166+
)
167+
return x_out
168+
169+
def forward(self, batch_dict):
170+
pillar_features, pillar_coords = batch_dict['pillar_features'], batch_dict['pillar_coords']
171+
batch_size = batch_dict['batch_size']
172+
input_sp_tensor = spconv.SparseConvTensor(
173+
features=pillar_features,
174+
indices=pillar_coords.int(),
175+
spatial_shape=self.sparse_shape,
176+
batch_size=batch_size
177+
)
178+
179+
x_conv1 = self.conv1(input_sp_tensor)
180+
x_conv2 = self.conv2(x_conv1)
181+
x_conv3 = self.conv3(x_conv2)
182+
x_conv4 = self.conv4(x_conv3)
183+
x_conv5 = self.conv5(x_conv4)
184+
x_conv6 = self.conv6(x_conv5)
185+
186+
x_conv5.indices[:, 1:] *= 2
187+
x_conv6.indices[:, 1:] *= 4
188+
x_conv4 = x_conv4.replace_feature(torch.cat([x_conv4.features, x_conv5.features, x_conv6.features]))
189+
x_conv4.indices = torch.cat([x_conv4.indices, x_conv5.indices, x_conv6.indices])
190+
191+
out = self.bev_out(x_conv4)
192+
193+
out = self.conv_out(out)
194+
out = self.shared_conv(out)
195+
196+
batch_dict.update({
197+
'encoded_spconv_tensor': out,
198+
'encoded_spconv_tensor_stride': 8
199+
})
200+
batch_dict.update({
201+
'multi_scale_2d_features': {
202+
'x_conv1': x_conv1,
203+
'x_conv2': x_conv2,
204+
'x_conv3': x_conv3,
205+
'x_conv4': x_conv4,
206+
'x_conv5': x_conv5,
207+
}
208+
})
209+
batch_dict.update({
210+
'multi_scale_2d_strides': {
211+
'x_conv1': 1,
212+
'x_conv2': 2,
213+
'x_conv3': 4,
214+
'x_conv4': 8,
215+
'x_conv5': 16,
216+
}
217+
})
218+
219+
return batch_dict

0 commit comments

Comments
 (0)