Skip to content

Commit ebd6b75

Browse files
authored
Fix torch2onnx for pointpillars with multi-level outputs (#2210)
* temp fix * fix * update
1 parent d7489c8 commit ebd6b75

File tree

3 files changed

+43
-10
lines changed

3 files changed

+43
-10
lines changed

configs/mmdet3d/voxel-detection/voxel-detection_static.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
type='mmdet3d', task='VoxelDetection', model_type='end2end')
44
onnx_config = dict(
55
input_names=['voxels', 'num_points', 'coors'],
6-
output_names=['cls_score', 'bbox_pred', 'dir_cls_pred'])
6+
# need to change output_names for head with multi-level features
7+
output_names=['cls_score0', 'bbox_pred0', 'dir_cls_pred0'])

mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,14 @@ def forward(self,
9090
}
9191

9292
outputs = self.wrapper(input_dict)
93-
93+
num_level = len(outputs) // 3
94+
new_outputs = dict(
95+
cls_score=[outputs[f'cls_score{i}'] for i in range(num_level)],
96+
bbox_pred=[outputs[f'bbox_pred{i}'] for i in range(num_level)],
97+
dir_cls_pred=[
98+
outputs[f'dir_cls_pred{i}'] for i in range(num_level)
99+
])
100+
outputs = new_outputs
94101
if data_samples is None:
95102
return outputs
96103

@@ -239,9 +246,9 @@ def postprocess(model_cfg: Union[str, Config],
239246

240247
if not hasattr(head, 'task_heads'):
241248
data_instances_3d = head.predict_by_feat(
242-
cls_scores=[cls_score],
243-
bbox_preds=[bbox_pred],
244-
dir_cls_preds=[dir_cls_pred],
249+
cls_scores=cls_score,
250+
bbox_preds=bbox_pred,
251+
dir_cls_preds=dir_cls_pred,
245252
batch_input_metas=batch_input_metas,
246253
cfg=cfg)
247254

mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33

44
from mmdeploy.core import FUNCTION_REWRITER
5+
from mmdeploy.utils import get_ir_config
56

67

78
@FUNCTION_REWRITER.register_rewriter(
@@ -52,10 +53,21 @@ def mvxtwostagedetector__forward(self, inputs: list, **kwargs):
5253
inputs (list): input list comprises voxels, num_points and coors
5354
5455
Returns:
55-
bbox (Tensor): Decoded bbox after nms
56-
scores (Tensor): bbox scores
57-
labels (Tensor): bbox labels
56+
tuple: A tuple of classification scores, bbox and direction
57+
classification prediction.
58+
59+
- cls_scores (list[Tensor]): Classification scores for all
60+
scale levels, each is a 4D-tensor, the channels number
61+
is num_base_priors * num_classes.
62+
- bbox_preds (list[Tensor]): Box energies / deltas for all
63+
scale levels, each is a 4D-tensor, the channels number
64+
is num_base_priors * C.
65+
- dir_cls_preds (list[Tensor|None]): Direction classification
66+
predictions for all scale levels, each is a 4D-tensor,
67+
the channels number is num_base_priors * 2.
5868
"""
69+
ctx = FUNCTION_REWRITER.get_context()
70+
deploy_cfg = ctx.cfg
5971
batch_inputs_dict = {
6072
'voxels': {
6173
'voxels': inputs[0],
@@ -82,5 +94,18 @@ def mvxtwostagedetector__forward(self, inputs: list, **kwargs):
8294
dir_scores = torch.cat(dir_scores, dim=1)
8395
return scores, bbox_preds, dir_scores
8496
else:
85-
cls_score, bbox_pred, dir_cls_pred = outs[0][0], outs[1][0], outs[2][0]
86-
return cls_score, bbox_pred, dir_cls_pred
97+
preds = []
98+
expect_names = []
99+
for i in range(len(outs[0])):
100+
preds += [outs[0][i], outs[1][i], outs[2][i]]
101+
expect_names += [
102+
f'cls_score{i}', f'bbox_pred{i}', f'dir_cls_pred{i}'
103+
]
104+
# check if output_names is set correctly.
105+
onnx_cfg = get_ir_config(deploy_cfg)
106+
output_names = onnx_cfg['output_names']
107+
if output_names != list(expect_names):
108+
raise RuntimeError(f'`output_names` should be {expect_names} '
109+
f'but given {output_names}\n'
110+
f'Deploy config:\n{deploy_cfg.pretty_text}')
111+
return tuple(preds)

0 commit comments

Comments
 (0)