Skip to content

Commit fa724b1

Browse files
authored
Add type hint for middle_encoder and voxel_encoder (#2556)
* 2023/05/26 add type hint * 2023/05/26 modify ugly typehint
1 parent 8e634dd commit fa724b1

File tree

7 files changed

+199
-139
lines changed

7 files changed

+199
-139
lines changed

mmdet3d/models/middle_encoders/pillar_scatter.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import List
3+
24
import torch
3-
from torch import nn
5+
from torch import Tensor, nn
46

57
from mmdet3d.registry import MODELS
68

@@ -16,14 +18,17 @@ class PointPillarsScatter(nn.Module):
1618
output_shape (list[int]): Required output shape of features.
1719
"""
1820

19-
def __init__(self, in_channels, output_shape):
21+
def __init__(self, in_channels: int, output_shape: List[int]):
2022
super().__init__()
2123
self.output_shape = output_shape
2224
self.ny = output_shape[0]
2325
self.nx = output_shape[1]
2426
self.in_channels = in_channels
2527

26-
def forward(self, voxel_features, coors, batch_size=None):
28+
def forward(self,
29+
voxel_features: Tensor,
30+
coors: Tensor,
31+
batch_size: int = None) -> Tensor:
2732
"""Foraward function to scatter features."""
2833
# TODO: rewrite the function in a batch manner
2934
# no need to deal with different batch cases
@@ -32,7 +37,7 @@ def forward(self, voxel_features, coors, batch_size=None):
3237
else:
3338
return self.forward_single(voxel_features, coors)
3439

35-
def forward_single(self, voxel_features, coors):
40+
def forward_single(self, voxel_features: Tensor, coors: Tensor) -> Tensor:
3641
"""Scatter features of single sample.
3742
3843
Args:
@@ -56,7 +61,8 @@ def forward_single(self, voxel_features, coors):
5661
canvas = canvas.view(1, self.in_channels, self.ny, self.nx)
5762
return canvas
5863

59-
def forward_batch(self, voxel_features, coors, batch_size):
64+
def forward_batch(self, voxel_features: Tensor, coors: Tensor,
65+
batch_size: int) -> Tensor:
6066
"""Scatter features of single sample.
6167
6268
Args:

mmdet3d/models/middle_encoders/sparse_encoder.py

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from typing import List, Tuple
2+
from typing import Dict, List, Optional, Tuple, Union
33

44
import torch
55
from mmcv.ops import points_in_boxes_all, three_interpolate, three_nn
@@ -18,6 +18,8 @@
1818
else:
1919
from mmcv.ops import SparseConvTensor, SparseSequential
2020

21+
TwoTupleIntType = Tuple[Tuple[int]]
22+
2123

2224
@MODELS.register_module()
2325
class SparseEncoder(nn.Module):
@@ -26,7 +28,7 @@ class SparseEncoder(nn.Module):
2628
Args:
2729
in_channels (int): The number of input channels.
2830
sparse_shape (list[int]): The sparse shape of input tensor.
29-
order (list[str], optional): Order of conv module.
31+
order (tuple[str], optional): Order of conv module.
3032
Defaults to ('conv', 'norm', 'act').
3133
norm_cfg (dict, optional): Config of normalization layer. Defaults to
3234
dict(type='BN1d', eps=1e-3, momentum=0.01).
@@ -46,19 +48,24 @@ class SparseEncoder(nn.Module):
4648
Default to False.
4749
"""
4850

49-
def __init__(self,
50-
in_channels,
51-
sparse_shape,
52-
order=('conv', 'norm', 'act'),
53-
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
54-
base_channels=16,
55-
output_channels=128,
56-
encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
57-
64)),
58-
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
59-
1)),
60-
block_type='conv_module',
61-
return_middle_feats=False):
51+
def __init__(
52+
self,
53+
in_channels: int,
54+
sparse_shape: List[int],
55+
order: Optional[Tuple[str]] = ('conv', 'norm', 'act'),
56+
norm_cfg: Optional[dict] = dict(
57+
type='BN1d', eps=1e-3, momentum=0.01),
58+
base_channels: Optional[int] = 16,
59+
output_channels: Optional[int] = 128,
60+
encoder_channels: Optional[TwoTupleIntType] = ((16, ), (32, 32,
61+
32),
62+
(64, 64,
63+
64), (64, 64, 64)),
64+
encoder_paddings: Optional[TwoTupleIntType] = ((1, ), (1, 1, 1),
65+
(1, 1, 1),
66+
((0, 1, 1), 1, 1)),
67+
block_type: Optional[str] = 'conv_module',
68+
return_middle_feats: Optional[bool] = False):
6269
super().__init__()
6370
assert block_type in ['conv_module', 'basicblock']
6471
self.sparse_shape = sparse_shape
@@ -112,7 +119,8 @@ def __init__(self,
112119
conv_type='SparseConv3d')
113120

114121
@amp.autocast(enabled=False)
115-
def forward(self, voxel_features, coors, batch_size):
122+
def forward(self, voxel_features: Tensor, coors: Tensor,
123+
batch_size: int) -> Union[Tensor, Tuple[Tensor, list]]:
116124
"""Forward of SparseEncoder.
117125
118126
Args:
@@ -154,12 +162,14 @@ def forward(self, voxel_features, coors, batch_size):
154162
else:
155163
return spatial_features
156164

157-
def make_encoder_layers(self,
158-
make_block,
159-
norm_cfg,
160-
in_channels,
161-
block_type='conv_module',
162-
conv_cfg=dict(type='SubMConv3d')):
165+
def make_encoder_layers(
166+
self,
167+
make_block: nn.Module,
168+
norm_cfg: Dict,
169+
in_channels: int,
170+
block_type: Optional[str] = 'conv_module',
171+
conv_cfg: Optional[dict] = dict(type='SubMConv3d')
172+
) -> int:
163173
"""make encoder layers using sparse convs.
164174
165175
Args:
@@ -256,18 +266,22 @@ class SparseEncoderSASSD(SparseEncoder):
256266
Defaults to 'conv_module'.
257267
"""
258268

259-
def __init__(self,
260-
in_channels: int,
261-
sparse_shape: List[int],
262-
order: Tuple[str] = ('conv', 'norm', 'act'),
263-
norm_cfg: dict = dict(type='BN1d', eps=1e-3, momentum=0.01),
264-
base_channels: int = 16,
265-
output_channels: int = 128,
266-
encoder_channels: Tuple[tuple] = ((16, ), (32, 32, 32),
267-
(64, 64, 64), (64, 64, 64)),
268-
encoder_paddings: Tuple[tuple] = ((1, ), (1, 1, 1), (1, 1, 1),
269-
((0, 1, 1), 1, 1)),
270-
block_type: str = 'conv_module'):
269+
def __init__(
270+
self,
271+
in_channels: int,
272+
sparse_shape: List[int],
273+
order: Tuple[str] = ('conv', 'norm', 'act'),
274+
norm_cfg: dict = dict(type='BN1d', eps=1e-3, momentum=0.01),
275+
base_channels: int = 16,
276+
output_channels: int = 128,
277+
encoder_channels: Optional[TwoTupleIntType] = ((16, ), (32, 32,
278+
32),
279+
(64, 64,
280+
64), (64, 64, 64)),
281+
encoder_paddings: Optional[TwoTupleIntType] = ((1, ), (1, 1, 1),
282+
(1, 1, 1),
283+
((0, 1, 1), 1, 1)),
284+
block_type: str = 'conv_module'):
271285
super(SparseEncoderSASSD, self).__init__(
272286
in_channels=in_channels,
273287
sparse_shape=sparse_shape,

mmdet3d/models/middle_encoders/sparse_unet.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Dict, List, Optional, Tuple
3+
24
import torch
5+
from torch import Tensor, nn
36

47
from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE
58

@@ -14,6 +17,8 @@
1417
from mmdet3d.models.layers.sparse_block import replace_feature
1518
from mmdet3d.registry import MODELS
1619

20+
TwoTupleIntType = Tuple[Tuple[int]]
21+
1722

1823
@MODELS.register_module()
1924
class SparseUNet(BaseModule):
@@ -35,21 +40,28 @@ class SparseUNet(BaseModule):
3540
decoder_paddings (tuple[tuple[int]]): Paddings of each decode block.
3641
"""
3742

38-
def __init__(self,
39-
in_channels,
40-
sparse_shape,
41-
order=('conv', 'norm', 'act'),
42-
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
43-
base_channels=16,
44-
output_channels=128,
45-
encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
46-
64)),
47-
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
48-
1)),
49-
decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
50-
(16, 16, 16)),
51-
decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1)),
52-
init_cfg=None):
43+
def __init__(
44+
self,
45+
in_channels: int,
46+
sparse_shape: List[int],
47+
order: Tuple[str] = ('conv', 'norm', 'act'),
48+
norm_cfg: dict = dict(type='BN1d', eps=1e-3, momentum=0.01),
49+
base_channels: int = 16,
50+
output_channels: int = 128,
51+
encoder_channels: Optional[TwoTupleIntType] = ((16, ), (32, 32,
52+
32),
53+
(64, 64,
54+
64), (64, 64, 64)),
55+
encoder_paddings: Optional[TwoTupleIntType] = ((1, ), (1, 1, 1),
56+
(1, 1, 1),
57+
((0, 1, 1), 1, 1)),
58+
decoder_channels: Optional[TwoTupleIntType] = ((64, 64,
59+
64), (64, 64, 32),
60+
(32, 32,
61+
16), (16, 16, 16)),
62+
decoder_paddings: Optional[TwoTupleIntType] = ((1, 0), (1, 0),
63+
(0, 0), (0, 1)),
64+
init_cfg: bool = None):
5365
super().__init__(init_cfg=init_cfg)
5466
self.sparse_shape = sparse_shape
5567
self.in_channels = in_channels
@@ -101,7 +113,8 @@ def __init__(self,
101113
indice_key='spconv_down2',
102114
conv_type='SparseConv3d')
103115

104-
def forward(self, voxel_features, coors, batch_size):
116+
def forward(self, voxel_features: Tensor, coors: Tensor,
117+
batch_size: int) -> Dict[str, Tensor]:
105118
"""Forward of SparseUNet.
106119
107120
Args:
@@ -152,8 +165,10 @@ def forward(self, voxel_features, coors, batch_size):
152165

153166
return ret
154167

155-
def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
156-
merge_layer, upsample_layer):
168+
def decoder_layer_forward(
169+
self, x_lateral: SparseConvTensor, x_bottom: SparseConvTensor,
170+
lateral_layer: SparseBasicBlock, merge_layer: SparseSequential,
171+
upsample_layer: SparseSequential) -> SparseConvTensor:
157172
"""Forward of upsample and residual block.
158173
159174
Args:
@@ -176,7 +191,8 @@ def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
176191
return x
177192

178193
@staticmethod
179-
def reduce_channel(x, out_channels):
194+
def reduce_channel(x: SparseConvTensor,
195+
out_channels: int) -> SparseConvTensor:
180196
"""reduce channel for element-wise addition.
181197
182198
Args:
@@ -194,7 +210,8 @@ def reduce_channel(x, out_channels):
194210
x = replace_feature(x, features.view(n, out_channels, -1).sum(dim=2))
195211
return x
196212

197-
def make_encoder_layers(self, make_block, norm_cfg, in_channels):
213+
def make_encoder_layers(self, make_block: nn.Module, norm_cfg: dict,
214+
in_channels: int) -> int:
198215
"""make encoder layers using sparse convs.
199216
200217
Args:
@@ -240,7 +257,8 @@ def make_encoder_layers(self, make_block, norm_cfg, in_channels):
240257
self.encoder_layers.add_module(stage_name, stage_layers)
241258
return out_channels
242259

243-
def make_decoder_layers(self, make_block, norm_cfg, in_channels):
260+
def make_decoder_layers(self, make_block: nn.Module, norm_cfg: dict,
261+
in_channels: int) -> int:
244262
"""make decoder layers using sparse convs.
245263
246264
Args:

mmdet3d/models/middle_encoders/voxel_set_abstraction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
from mmcv.cnn import ConvModule
88
from mmcv.ops.furthest_point_sample import furthest_point_sample
99
from mmengine.model import BaseModule
10+
from torch import Tensor
1011

1112
from mmdet3d.registry import MODELS
1213
from mmdet3d.utils import InstanceList
1314

1415

15-
def bilinear_interpolate_torch(inputs, x, y):
16+
def bilinear_interpolate_torch(inputs: Tensor, x: Tensor, y: Tensor) -> Tensor:
1617
"""Bilinear interpolate for inputs."""
1718
x0 = torch.floor(x).long()
1819
x1 = x0 + 1

0 commit comments

Comments
 (0)