Skip to content

Commit 98d2642

Browse files
[Feature] Spvcnn backbone (#2320)
* add cylindrical voxelization & voxel feature encoder * add cylindrical voxelization & voxel feature encoder * add voxel-wise label & voxelization UT * fix vfe * fix vfe UT * rename voxel encoder & add more test case * fix type hint * temporarily refactoring mmcv's voxelize and dynamic in mmdet3d for data_preprocesser * _forward * del checkpoints * add if tp * add predict * fix vfe init bug & fix UT * add grid_size & move voxelization code * fix import bug * keep radian to follow origin * add doc string * fix type hint * add minkunet voxelization and loss function * fix data * init train * fix sparsetensor typehint * rename dir * fix data config * fix data config * fix batch_size & replace dynamic_scatter * fix conflicts 2 * fix conflicts on s_70 * Alignment of the original implementation * rename config * add worker_init_fn_hook * remove test_config & worker hook * add UT * fix polarmix UT * init spcvnn backbone * add seed for cr0p5 * spvcnn_init * format * rename SemanticKittiDataset * add platte & fix visual bug * add platte & fix data info bug * fix ut * fix ut * fix semantic_kitti ut * train init * fix docstring * fix config name * rename layer * fix doc string * fix review * remove filter data * rename config * rename backbone * rename backbone 2 * refactor voxel2point * fix coors typo * fix ut * fix ut * pred in segmentor * fix get voxel seg * resolve comments * rename p2v and v2p * rename points and voxels
1 parent f4b0174 commit 98d2642

File tree

7 files changed

+374
-1
lines changed

7 files changed

+374
-1
lines changed

configs/_base_/models/spvcnn.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
model = dict(
2+
type='MinkUNet',
3+
data_preprocessor=dict(
4+
type='Det3DDataPreprocessor',
5+
voxel=True,
6+
voxel_type='minkunet',
7+
voxel_layer=dict(
8+
max_num_points=-1,
9+
point_cloud_range=[-100, -100, -20, 100, 100, 20],
10+
voxel_size=[0.05, 0.05, 0.05],
11+
max_voxels=(-1, -1)),
12+
),
13+
backbone=dict(
14+
type='SPVCNNBackbone',
15+
in_channels=4,
16+
base_channels=32,
17+
encoder_channels=[32, 64, 128, 256],
18+
decoder_channels=[256, 128, 96, 96],
19+
num_stages=4,
20+
drop_ratio=0.3),
21+
decode_head=dict(
22+
type='MinkUNetHead',
23+
channels=96,
24+
num_classes=19,
25+
dropout_ratio=0,
26+
loss_decode=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
27+
ignore_index=19),
28+
train_cfg=dict(),
29+
test_cfg=dict())
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
_base_ = ['./spvcnn_w32_8xb2-15e_semantickitti.py']
2+
3+
model = dict(
4+
backbone=dict(
5+
base_channels=16,
6+
encoder_channels=[16, 32, 64, 128],
7+
decoder_channels=[128, 64, 48, 48]),
8+
decode_head=dict(channels=48))
9+
10+
randomness = dict(seed=1588147245)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
_base_ = ['./spvcnn_w32_8xb2-15e_semantickitti.py']
2+
3+
model = dict(
4+
backbone=dict(
5+
base_channels=20,
6+
encoder_channels=[20, 40, 81, 163],
7+
decoder_channels=[163, 81, 61, 61]),
8+
decode_head=dict(channels=61))
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
_base_ = [
2+
'../_base_/datasets/semantickitti.py', '../_base_/models/spvcnn.py',
3+
'../_base_/default_runtime.py'
4+
]
5+
6+
train_pipeline = [
7+
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
8+
dict(
9+
type='LoadAnnotations3D',
10+
with_bbox_3d=False,
11+
with_label_3d=False,
12+
with_seg_3d=True,
13+
seg_3d_dtype='np.int32',
14+
seg_offset=2**16,
15+
dataset_type='semantickitti'),
16+
dict(type='PointSegClassMapping'),
17+
dict(
18+
type='GlobalRotScaleTrans',
19+
rot_range=[0., 6.28318531],
20+
scale_ratio_range=[0.95, 1.05],
21+
translation_std=[0, 0, 0],
22+
),
23+
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
24+
]
25+
26+
train_dataloader = dict(
27+
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
28+
29+
lr = 0.24
30+
optim_wrapper = dict(
31+
type='AmpOptimWrapper',
32+
loss_scale='dynamic',
33+
optimizer=dict(
34+
type='SGD', lr=lr, weight_decay=0.0001, momentum=0.9, nesterov=True))
35+
36+
param_scheduler = [
37+
dict(
38+
type='LinearLR', start_factor=0.008, by_epoch=False, begin=0, end=125),
39+
dict(
40+
type='CosineAnnealingLR',
41+
begin=0,
42+
T_max=15,
43+
by_epoch=True,
44+
eta_min=1e-5,
45+
convert_to_iter_based=True)
46+
]
47+
48+
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=15, val_interval=1)
49+
val_cfg = dict(type='ValLoop')
50+
test_cfg = dict(type='TestLoop')
51+
52+
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1))
53+
randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
54+
env_cfg = dict(cudnn_benchmark=True)

mmdet3d/models/backbones/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
from .pointnet2_sa_msg import PointNet2SAMSG
1212
from .pointnet2_sa_ssg import PointNet2SASSG
1313
from .second import SECOND
14+
from .spvcnn_backone import SPVCNNBackbone
1415

1516
__all__ = [
1617
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
1718
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
1819
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv',
19-
'MinkUNetBackbone'
20+
'MinkUNetBackbone', 'SPVCNNBackbone'
2021
]
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Optional, Sequence
3+
4+
import torch
5+
from mmengine.registry import MODELS
6+
from torch import Tensor, nn
7+
8+
from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
9+
from mmdet3d.utils import OptMultiConfig
10+
from .minkunet_backbone import MinkUNetBackbone
11+
12+
if IS_TORCHSPARSE_AVAILABLE:
13+
import torchsparse
14+
import torchsparse.nn.functional as F
15+
from torchsparse.nn.utils import get_kernel_offsets
16+
from torchsparse.tensor import PointTensor, SparseTensor
17+
else:
18+
PointTensor = SparseTensor = None
19+
20+
21+
@MODELS.register_module()
22+
class SPVCNNBackbone(MinkUNetBackbone):
23+
"""SPVCNN backbone with torchsparse backend.
24+
25+
More details can be found in `paper <https://arxiv.org/abs/2007.16100>`_ .
26+
27+
Args:
28+
in_channels (int): Number of input voxel feature channels.
29+
Defaults to 4.
30+
base_channels (int): The input channels for first encoder layer.
31+
Defaults to 32.
32+
encoder_channels (List[int]): Convolutional channels of each encode
33+
layer. Defaults to [32, 64, 128, 256].
34+
decoder_channels (List[int]): Convolutional channels of each decode
35+
layer. Defaults to [256, 128, 96, 96].
36+
num_stages (int): Number of stages in encoder and decoder.
37+
Defaults to 4.
38+
drop_ratio (float): Dropout ratio of voxel features. Defaults to 0.3.
39+
init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`]
40+
, optional): Initialization config dict. Defaults to None.
41+
"""
42+
43+
def __init__(self,
44+
in_channels: int = 4,
45+
base_channels: int = 32,
46+
encoder_channels: Sequence[int] = [32, 64, 128, 256],
47+
decoder_channels: Sequence[int] = [256, 128, 96, 96],
48+
num_stages: int = 4,
49+
drop_ratio: float = 0.3,
50+
init_cfg: OptMultiConfig = None) -> None:
51+
super().__init__(
52+
in_channels=in_channels,
53+
base_channels=base_channels,
54+
encoder_channels=encoder_channels,
55+
decoder_channels=decoder_channels,
56+
num_stages=num_stages,
57+
init_cfg=init_cfg)
58+
59+
self.point_transforms = nn.ModuleList([
60+
nn.Sequential(
61+
nn.Linear(base_channels, encoder_channels[-1]),
62+
nn.BatchNorm1d(encoder_channels[-1]), nn.ReLU(True)),
63+
nn.Sequential(
64+
nn.Linear(encoder_channels[-1], decoder_channels[2]),
65+
nn.BatchNorm1d(decoder_channels[2]), nn.ReLU(True)),
66+
nn.Sequential(
67+
nn.Linear(decoder_channels[2], decoder_channels[4]),
68+
nn.BatchNorm1d(decoder_channels[4]), nn.ReLU(True))
69+
])
70+
self.dropout = nn.Dropout(drop_ratio, True)
71+
72+
def forward(self, voxel_features: Tensor, coors: Tensor) -> PointTensor:
73+
"""Forward function.
74+
75+
Args:
76+
voxel_features (Tensor): Voxel features in shape (N, C).
77+
coors (Tensor): Coordinates in shape (N, 4),
78+
the columns in the order of (x_idx, y_idx, z_idx, batch_idx).
79+
80+
Returns:
81+
PointTensor: Backbone features.
82+
"""
83+
voxels = SparseTensor(voxel_features, coors)
84+
points = PointTensor(voxels.F, voxels.C.float())
85+
voxels = self.initial_voxelize(points)
86+
87+
voxels = self.conv_input(voxels)
88+
points = self.voxel_to_point(voxels, points)
89+
voxels = self.point_to_voxel(voxels, points)
90+
laterals = [voxels]
91+
for encoder in self.encoder:
92+
voxels = encoder(voxels)
93+
laterals.append(voxels)
94+
laterals = laterals[:-1][::-1]
95+
96+
points = self.voxel_to_point(voxels, points, self.point_transforms[0])
97+
voxels = self.point_to_voxel(voxels, points)
98+
voxels.F = self.dropout(voxels.F)
99+
100+
decoder_outs = []
101+
for i, decoder in enumerate(self.decoder):
102+
voxels = decoder[0](voxels)
103+
voxels = torchsparse.cat((voxels, laterals[i]))
104+
voxels = decoder[1](voxels)
105+
decoder_outs.append(voxels)
106+
if i == 1:
107+
points = self.voxel_to_point(voxels, points,
108+
self.point_transforms[1])
109+
voxels = self.point_to_voxel(voxels, points)
110+
voxels.F = self.dropout(voxels.F)
111+
112+
points = self.voxel_to_point(voxels, points, self.point_transforms[2])
113+
return points
114+
115+
def initial_voxelize(self, points: PointTensor) -> SparseTensor:
116+
"""Voxelization again based on input PointTensor.
117+
118+
Args:
119+
points (PointTensor): Input points after voxelization.
120+
121+
Returns:
122+
SparseTensor: New voxels.
123+
"""
124+
pc_hash = F.sphash(torch.floor(points.C).int())
125+
sparse_hash = torch.unique(pc_hash)
126+
idx_query = F.sphashquery(pc_hash, sparse_hash)
127+
counts = F.spcount(idx_query.int(), len(sparse_hash))
128+
129+
inserted_coords = F.spvoxelize(
130+
torch.floor(points.C), idx_query, counts)
131+
inserted_coords = torch.round(inserted_coords).int()
132+
inserted_feat = F.spvoxelize(points.F, idx_query, counts)
133+
134+
new_tensor = SparseTensor(inserted_feat, inserted_coords, 1)
135+
new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)
136+
points.additional_features['idx_query'][1] = idx_query
137+
points.additional_features['counts'][1] = counts
138+
return new_tensor
139+
140+
def voxel_to_point(self,
141+
voxels: SparseTensor,
142+
points: PointTensor,
143+
point_transform: Optional[nn.Module] = None,
144+
nearest: bool = False) -> PointTensor:
145+
"""Feed voxel features to points.
146+
147+
Args:
148+
voxels (SparseTensor): Input voxels.
149+
points (PointTensor): Input points.
150+
point_transform (nn.Module, optional): Point transform module
151+
for input point features. Defaults to None.
152+
nearest (bool): Whether to use nearest neighbor interpolation.
153+
Defaults to False.
154+
155+
Returns:
156+
PointTensor: Points with new features.
157+
"""
158+
if points.idx_query is None or points.weights is None or \
159+
points.idx_query.get(voxels.s) is None or \
160+
points.weights.get(voxels.s) is None:
161+
offsets = get_kernel_offsets(
162+
2, voxels.s, 1, device=points.F.device)
163+
old_hash = F.sphash(
164+
torch.cat([
165+
torch.floor(points.C[:, :3] / voxels.s[0]).int() *
166+
voxels.s[0], points.C[:, -1].int().view(-1, 1)
167+
], 1), offsets)
168+
pc_hash = F.sphash(voxels.C.to(points.F.device))
169+
idx_query = F.sphashquery(old_hash, pc_hash)
170+
weights = F.calc_ti_weights(
171+
points.C, idx_query,
172+
scale=voxels.s[0]).transpose(0, 1).contiguous()
173+
idx_query = idx_query.transpose(0, 1).contiguous()
174+
if nearest:
175+
weights[:, 1:] = 0.
176+
idx_query[:, 1:] = -1
177+
new_features = F.spdevoxelize(voxels.F, idx_query, weights)
178+
new_tensor = PointTensor(
179+
new_features,
180+
points.C,
181+
idx_query=points.idx_query,
182+
weights=points.weights)
183+
new_tensor.additional_features = points.additional_features
184+
new_tensor.idx_query[voxels.s] = idx_query
185+
new_tensor.weights[voxels.s] = weights
186+
points.idx_query[voxels.s] = idx_query
187+
points.weights[voxels.s] = weights
188+
else:
189+
new_features = F.spdevoxelize(voxels.F,
190+
points.idx_query.get(voxels.s),
191+
points.weights.get(voxels.s))
192+
new_tensor = PointTensor(
193+
new_features,
194+
points.C,
195+
idx_query=points.idx_query,
196+
weights=points.weights)
197+
new_tensor.additional_features = points.additional_features
198+
199+
if point_transform is not None:
200+
new_tensor.F = new_tensor.F + point_transform(points.F)
201+
202+
return new_tensor
203+
204+
def point_to_voxel(self, voxels: SparseTensor,
205+
points: PointTensor) -> SparseTensor:
206+
"""Feed point features to voxels.
207+
208+
Args:
209+
voxels (SparseTensor): Input voxels.
210+
points (PointTensor): Input points.
211+
212+
Returns:
213+
SparseTensor: Voxels with new features.
214+
"""
215+
if points.additional_features is None or \
216+
points.additional_features.get('idx_query') is None or \
217+
points.additional_features['idx_query'].get(voxels.s) is None:
218+
pc_hash = F.sphash(
219+
torch.cat([
220+
torch.floor(points.C[:, :3] / voxels.s[0]).int() *
221+
voxels.s[0], points.C[:, -1].int().view(-1, 1)
222+
], 1))
223+
sparse_hash = F.sphash(voxels.C)
224+
idx_query = F.sphashquery(pc_hash, sparse_hash)
225+
counts = F.spcount(idx_query.int(), voxels.C.shape[0])
226+
points.additional_features['idx_query'][voxels.s] = idx_query
227+
points.additional_features['counts'][voxels.s] = counts
228+
else:
229+
idx_query = points.additional_features['idx_query'][voxels.s]
230+
counts = points.additional_features['counts'][voxels.s]
231+
232+
inserted_features = F.spvoxelize(points.F, idx_query, counts)
233+
new_tensor = SparseTensor(inserted_features, voxels.C, voxels.s)
234+
new_tensor.cmaps = voxels.cmaps
235+
new_tensor.kmaps = voxels.kmaps
236+
237+
return new_tensor
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import pytest
3+
import torch
4+
import torch.nn.functional as F
5+
6+
from mmdet3d.registry import MODELS
7+
8+
9+
def test_spvcnn_backbone():
10+
if not torch.cuda.is_available():
11+
pytest.skip('test requires GPU and torch+cuda')
12+
13+
try:
14+
import torchsparse # noqa: F401
15+
except ImportError:
16+
pytest.skip('test requires Torchsparse installation')
17+
18+
coordinates, features = [], []
19+
for i in range(2):
20+
c = torch.randint(0, 10, (100, 3)).int()
21+
c = F.pad(c, (0, 1), mode='constant', value=i)
22+
coordinates.append(c)
23+
f = torch.rand(100, 4)
24+
features.append(f)
25+
features = torch.cat(features, dim=0).cuda()
26+
coordinates = torch.cat(coordinates, dim=0).cuda()
27+
28+
cfg = dict(type='SPVCNNBackbone')
29+
self = MODELS.build(cfg).cuda()
30+
self.init_weights()
31+
32+
y = self(features, coordinates)
33+
assert y.F.shape == torch.Size([200, 96])
34+
assert y.C.shape == torch.Size([200, 4])

0 commit comments

Comments
 (0)