Skip to content

Commit b5a706d

Browse files
[Feature] Add type hint for code in models/task_modules (#2485)
* 2023/04/28 add task_modules type hint * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/coders/groupfree3d_bbox_coder.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * Update mmdet3d/models/task_modules/anchor/anchor_3d_generator.py Co-authored-by: Xiang Xu <[email protected]> * 2023/05/19 fix_wrong_hint * Update centerpoint_bbox_coders.py * Update iou_neg_piecewise_sampler.py --------- Co-authored-by: Xiang Xu <[email protected]>
1 parent 5734aef commit b5a706d

17 files changed

+318
-200
lines changed

mmdet3d/models/task_modules/anchor/anchor_3d_generator.py

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import List, Tuple, Union
3+
24
import mmengine
35
import torch
6+
from torch import Tensor
47

58
from mmdet3d.registry import TASK_UTILS
69

@@ -37,13 +40,13 @@ class Anchor3DRangeGenerator(object):
3740
"""
3841

3942
def __init__(self,
40-
ranges,
41-
sizes=[[3.9, 1.6, 1.56]],
42-
scales=[1],
43-
rotations=[0, 1.5707963],
44-
custom_values=(),
45-
reshape_out=True,
46-
size_per_range=True):
43+
ranges: List[List[float]],
44+
sizes: List[List[float]] = [[3.9, 1.6, 1.56]],
45+
scales: List[int] = [1],
46+
rotations: List[float] = [0, 1.5707963],
47+
custom_values: Tuple[float] = (),
48+
reshape_out: bool = True,
49+
size_per_range: bool = True) -> None:
4750
assert mmengine.is_list_of(ranges, list)
4851
if size_per_range:
4952
if len(sizes) != len(ranges):
@@ -64,7 +67,7 @@ def __init__(self,
6467
self.reshape_out = reshape_out
6568
self.size_per_range = size_per_range
6669

67-
def __repr__(self):
70+
def __repr__(self) -> str:
6871
s = self.__class__.__name__ + '('
6972
s += f'anchor_range={self.ranges},\n'
7073
s += f'scales={self.scales},\n'
@@ -75,18 +78,21 @@ def __repr__(self):
7578
return s
7679

7780
@property
78-
def num_base_anchors(self):
79-
"""list[int]: Total number of base anchors in a feature grid."""
81+
def num_base_anchors(self) -> int:
82+
"""int: Total number of base anchors in a feature grid."""
8083
num_rot = len(self.rotations)
8184
num_size = torch.tensor(self.sizes).reshape(-1, 3).size(0)
8285
return num_rot * num_size
8386

8487
@property
85-
def num_levels(self):
88+
def num_levels(self) -> int:
8689
"""int: Number of feature levels that the generator is applied to."""
8790
return len(self.scales)
8891

89-
def grid_anchors(self, featmap_sizes, device='cuda'):
92+
def grid_anchors(
93+
self,
94+
featmap_sizes: List[Tuple[int]],
95+
device: Union[str, torch.device] = 'cuda') -> List[Tensor]:
9096
"""Generate grid anchors in multiple feature levels.
9197
9298
Args:
@@ -112,7 +118,11 @@ def grid_anchors(self, featmap_sizes, device='cuda'):
112118
multi_level_anchors.append(anchors)
113119
return multi_level_anchors
114120

115-
def single_level_grid_anchors(self, featmap_size, scale, device='cuda'):
121+
def single_level_grid_anchors(
122+
self,
123+
featmap_size: Tuple[int],
124+
scale: int,
125+
device: Union[str, torch.device] = 'cuda') -> Tensor:
116126
"""Generate grid anchors of a single level feature map.
117127
118128
This function is usually called by method ``self.grid_anchors``.
@@ -152,13 +162,14 @@ def single_level_grid_anchors(self, featmap_size, scale, device='cuda'):
152162
mr_anchors = torch.cat(mr_anchors, dim=-3)
153163
return mr_anchors
154164

155-
def anchors_single_range(self,
156-
feature_size,
157-
anchor_range,
158-
scale=1,
159-
sizes=[[3.9, 1.6, 1.56]],
160-
rotations=[0, 1.5707963],
161-
device='cuda'):
165+
def anchors_single_range(
166+
self,
167+
feature_size: Tuple[int],
168+
anchor_range: Union[Tensor, List[float]],
169+
scale: int = 1,
170+
sizes: Union[List[List[float]], List[float]] = [[3.9, 1.6, 1.56]],
171+
rotations: List[float] = [0, 1.5707963],
172+
device: Union[str, torch.device] = 'cuda') -> Tensor:
162173
"""Generate anchors in a single range.
163174
164175
Args:
@@ -248,17 +259,18 @@ class AlignedAnchor3DRangeGenerator(Anchor3DRangeGenerator):
248259
center of the corresponding greature grid. Defaults to False.
249260
"""
250261

251-
def __init__(self, align_corner=False, **kwargs):
262+
def __init__(self, align_corner: bool = False, **kwargs) -> None:
252263
super(AlignedAnchor3DRangeGenerator, self).__init__(**kwargs)
253264
self.align_corner = align_corner
254265

255-
def anchors_single_range(self,
256-
feature_size,
257-
anchor_range,
258-
scale,
259-
sizes=[[3.9, 1.6, 1.56]],
260-
rotations=[0, 1.5707963],
261-
device='cuda'):
266+
def anchors_single_range(
267+
self,
268+
feature_size: List[int],
269+
anchor_range: List[float],
270+
scale: int,
271+
sizes: Union[List[List[float]], List[float]] = [[3.9, 1.6, 1.56]],
272+
rotations: List[float] = [0, 1.5707963],
273+
device: Union[str, torch.device] = 'cuda') -> Tensor:
262274
"""Generate anchors in a single range.
263275
264276
Args:
@@ -352,12 +364,15 @@ class AlignedAnchor3DRangeGeneratorPerCls(AlignedAnchor3DRangeGenerator):
352364
:class:`AlignedAnchor3DRangeGenerator`.
353365
"""
354366

355-
def __init__(self, **kwargs):
367+
def __init__(self, **kwargs) -> None:
356368
super(AlignedAnchor3DRangeGeneratorPerCls, self).__init__(**kwargs)
357369
assert len(self.scales) == 1, 'Multi-scale feature map levels are' + \
358370
' not supported currently in this kind of anchor generator.'
359371

360-
def grid_anchors(self, featmap_sizes, device='cuda'):
372+
def grid_anchors(
373+
self,
374+
featmap_sizes: List[Tuple[int]],
375+
device: Union[str, torch.device] = 'cuda') -> List[List[Tensor]]:
361376
"""Generate grid anchors in multiple feature levels.
362377
363378
Args:
@@ -379,7 +394,11 @@ def grid_anchors(self, featmap_sizes, device='cuda'):
379394
multi_level_anchors.append(anchors)
380395
return multi_level_anchors
381396

382-
def multi_cls_grid_anchors(self, featmap_sizes, scale, device='cuda'):
397+
def multi_cls_grid_anchors(
398+
self,
399+
featmap_sizes: List[Tuple[int]],
400+
scale: int,
401+
device: Union[str, torch.device] = 'cuda') -> List[Tensor]:
383402
"""Generate grid anchors of a single level feature map for multi-class
384403
with different feature map sizes.
385404

mmdet3d/models/task_modules/anchor/builder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import warnings
3+
from typing import Any
34

45
from mmdet3d.registry import TASK_UTILS
6+
from mmdet3d.utils import ConfigType
57

68
PRIOR_GENERATORS = TASK_UTILS
79

810
ANCHOR_GENERATORS = TASK_UTILS
911

1012

11-
def build_prior_generator(cfg, default_args=None):
13+
def build_prior_generator(cfg: ConfigType, default_args=None) -> Any:
1214
warnings.warn(
1315
'``build_prior_generator`` would be deprecated soon, please use '
1416
'``mmdet3d.registry.TASK_UTILS.build()`` ')
1517
return TASK_UTILS.build(cfg, default_args=default_args)
1618

1719

18-
def build_anchor_generator(cfg, default_args=None):
20+
def build_anchor_generator(cfg: ConfigType, default_args=None) -> Any:
1921
warnings.warn(
2022
'``build_anchor_generator`` would be deprecated soon, please use '
2123
'``mmdet3d.registry.TASK_UTILS.build()`` ')

mmdet3d/models/task_modules/assigners/max_3d_iou_assigner.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,18 @@ class Max3DIoUAssigner(MaxIoUAssigner):
4242
iou_calculator (dict): Config of overlaps Calculator.
4343
"""
4444

45-
def __init__(self,
46-
pos_iou_thr: float,
47-
neg_iou_thr: Union[float, tuple],
48-
min_pos_iou: float = .0,
49-
gt_max_assign_all: bool = True,
50-
ignore_iof_thr: float = -1,
51-
ignore_wrt_candidates: bool = True,
52-
match_low_quality: bool = True,
53-
gpu_assign_thr: float = -1,
54-
iou_calculator: dict = dict(type='BboxOverlaps2D')):
45+
def __init__(
46+
self,
47+
pos_iou_thr: float,
48+
neg_iou_thr: Union[float, tuple],
49+
min_pos_iou: float = .0,
50+
gt_max_assign_all: bool = True,
51+
ignore_iof_thr: float = -1,
52+
ignore_wrt_candidates: bool = True,
53+
match_low_quality: bool = True,
54+
gpu_assign_thr: float = -1,
55+
iou_calculator: dict = dict(type='BboxOverlaps2D')
56+
) -> None:
5557
self.pos_iou_thr = pos_iou_thr
5658
self.neg_iou_thr = neg_iou_thr
5759
self.min_pos_iou = min_pos_iou

mmdet3d/models/task_modules/builder.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,30 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import warnings
3+
from typing import Any
34

45
from mmdet3d.registry import TASK_UTILS
6+
from mmdet3d.utils.typing_utils import ConfigType
57

68
BBOX_ASSIGNERS = TASK_UTILS
79
BBOX_SAMPLERS = TASK_UTILS
810
BBOX_CODERS = TASK_UTILS
911

1012

11-
def build_assigner(cfg, **default_args):
13+
def build_assigner(cfg: ConfigType, **default_args) -> Any:
1214
"""Builder of box assigner."""
1315
warnings.warn('``build_assigner`` would be deprecated soon, please use '
1416
'``mmdet3d.registry.TASK_UTILS.build()`` ')
1517
return TASK_UTILS.build(cfg, default_args=default_args)
1618

1719

18-
def build_sampler(cfg, **default_args):
20+
def build_sampler(cfg: ConfigType, **default_args) -> Any:
1921
"""Builder of box sampler."""
2022
warnings.warn('``build_sampler`` would be deprecated soon, please use '
2123
'``mmdet3d.registry.TASK_UTILS.build()`` ')
2224
return TASK_UTILS.build(cfg, default_args=default_args)
2325

2426

25-
def build_bbox_coder(cfg, **default_args):
27+
def build_bbox_coder(cfg: ConfigType, **default_args) -> Any:
2628
"""Builder of box coder."""
2729
warnings.warn('``build_bbox_coder`` would be deprecated soon, please use '
2830
'``mmdet3d.registry.TASK_UTILS.build()`` ')

mmdet3d/models/task_modules/coders/anchor_free_bbox_coder.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Dict
3+
24
import numpy as np
35
import torch
6+
from torch import Tensor
47

58
from mmdet3d.registry import TASK_UTILS
9+
from mmdet3d.structures import BaseInstance3DBoxes
610
from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder
711

812

@@ -15,13 +19,14 @@ class AnchorFreeBBoxCoder(PartialBinBasedBBoxCoder):
1519
with_rot (bool): Whether the bbox is with rotation.
1620
"""
1721

18-
def __init__(self, num_dir_bins, with_rot=True):
22+
def __init__(self, num_dir_bins: int, with_rot: bool = True) -> None:
1923
super(AnchorFreeBBoxCoder, self).__init__(
2024
num_dir_bins, 0, [], with_rot=with_rot)
2125
self.num_dir_bins = num_dir_bins
2226
self.with_rot = with_rot
2327

24-
def encode(self, gt_bboxes_3d, gt_labels_3d):
28+
def encode(self, gt_bboxes_3d: BaseInstance3DBoxes,
29+
gt_labels_3d: Tensor) -> tuple:
2530
"""Encode ground truth to prediction targets.
2631
2732
Args:
@@ -51,7 +56,7 @@ def encode(self, gt_bboxes_3d, gt_labels_3d):
5156
return (center_target, size_res_target, dir_class_target,
5257
dir_res_target)
5358

54-
def decode(self, bbox_out):
59+
def decode(self, bbox_out: dict) -> Tensor:
5560
"""Decode predicted parts to bbox3d.
5661
5762
Args:
@@ -85,7 +90,8 @@ def decode(self, bbox_out):
8590
bbox3d = torch.cat([center, bbox_size, dir_angle], dim=-1)
8691
return bbox3d
8792

88-
def split_pred(self, cls_preds, reg_preds, base_xyz):
93+
def split_pred(self, cls_preds: Tensor, reg_preds: Tensor,
94+
base_xyz: Tensor) -> Dict[str, Tensor]:
8995
"""Split predicted features to specific parts.
9096
9197
Args:

mmdet3d/models/task_modules/coders/centerpoint_bbox_coders.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Dict, List, Optional, Tuple
3+
24
import torch
35
from mmdet.models.task_modules import BaseBBoxCoder
6+
from torch import Tensor
47

58
from mmdet3d.registry import TASK_UTILS
69

@@ -22,13 +25,13 @@ class CenterPointBBoxCoder(BaseBBoxCoder):
2225
"""
2326

2427
def __init__(self,
25-
pc_range,
26-
out_size_factor,
27-
voxel_size,
28-
post_center_range=None,
29-
max_num=100,
30-
score_threshold=None,
31-
code_size=9):
28+
pc_range: List[float],
29+
out_size_factor: int,
30+
voxel_size: List[float],
31+
post_center_range: Optional[List[float]] = None,
32+
max_num: int = 100,
33+
score_threshold: Optional[float] = None,
34+
code_size: int = 9) -> None:
3235

3336
self.pc_range = pc_range
3437
self.out_size_factor = out_size_factor
@@ -38,7 +41,10 @@ def __init__(self,
3841
self.score_threshold = score_threshold
3942
self.code_size = code_size
4043

41-
def _gather_feat(self, feats, inds, feat_masks=None):
44+
def _gather_feat(self,
45+
feats: Tensor,
46+
inds: Tensor,
47+
feat_masks: Optional[Tensor] = None) -> Tensor:
4248
"""Given feats and indexes, returns the gathered feats.
4349
4450
Args:
@@ -60,7 +66,7 @@ def _gather_feat(self, feats, inds, feat_masks=None):
6066
feats = feats.view(-1, dim)
6167
return feats
6268

63-
def _topk(self, scores, K=80):
69+
def _topk(self, scores: Tensor, K: int = 80) -> Tuple[Tensor]:
6470
"""Get indexes based on scores.
6571
6672
Args:
@@ -95,7 +101,7 @@ def _topk(self, scores, K=80):
95101

96102
return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
97103

98-
def _transpose_and_gather_feat(self, feat, ind):
104+
def _transpose_and_gather_feat(self, feat: Tensor, ind: Tensor) -> Tensor:
99105
"""Given feats and indexes, returns the transposed and gathered feats.
100106
101107
Args:
@@ -115,14 +121,14 @@ def encode(self):
115121
pass
116122

117123
def decode(self,
118-
heat,
119-
rot_sine,
120-
rot_cosine,
121-
hei,
122-
dim,
123-
vel,
124-
reg=None,
125-
task_id=-1):
124+
heat: Tensor,
125+
rot_sine: Tensor,
126+
rot_cosine: Tensor,
127+
hei: Tensor,
128+
dim: Tensor,
129+
vel: Tensor,
130+
reg: Optional[Tensor] = None,
131+
task_id: int = -1) -> List[Dict[str, Tensor]]:
126132
"""Decode bboxes.
127133
128134
Args:

0 commit comments

Comments
 (0)