22import torch
33
44from mmdeploy .core import FUNCTION_REWRITER
5+ from mmdeploy .utils import get_ir_config
56
67
78@FUNCTION_REWRITER .register_rewriter (
@@ -52,10 +53,21 @@ def mvxtwostagedetector__forward(self, inputs: list, **kwargs):
5253 inputs (list): input list comprises voxels, num_points and coors
5354
5455 Returns:
55- bbox (Tensor): Decoded bbox after nms
56- scores (Tensor): bbox scores
57- labels (Tensor): bbox labels
56+ tuple: A tuple of classification scores, bbox and direction
57+ classification prediction.
58+
59+ - cls_scores (list[Tensor]): Classification scores for all
60+ scale levels, each is a 4D-tensor, the channels number
61+ is num_base_priors * num_classes.
62+ - bbox_preds (list[Tensor]): Box energies / deltas for all
63+ scale levels, each is a 4D-tensor, the channels number
64+ is num_base_priors * C.
65+ - dir_cls_preds (list[Tensor|None]): Direction classification
66+ predictions for all scale levels, each is a 4D-tensor,
67+ the channels number is num_base_priors * 2.
5868 """
69+ ctx = FUNCTION_REWRITER .get_context ()
70+ deploy_cfg = ctx .cfg
5971 batch_inputs_dict = {
6072 'voxels' : {
6173 'voxels' : inputs [0 ],
@@ -82,5 +94,18 @@ def mvxtwostagedetector__forward(self, inputs: list, **kwargs):
8294 dir_scores = torch .cat (dir_scores , dim = 1 )
8395 return scores , bbox_preds , dir_scores
8496 else :
85- cls_score , bbox_pred , dir_cls_pred = outs [0 ][0 ], outs [1 ][0 ], outs [2 ][0 ]
86- return cls_score , bbox_pred , dir_cls_pred
97+ preds = []
98+ expect_names = []
99+ for i in range (len (outs [0 ])):
100+ preds += [outs [0 ][i ], outs [1 ][i ], outs [2 ][i ]]
101+ expect_names += [
102+ f'cls_score{ i } ' , f'bbox_pred{ i } ' , f'dir_cls_pred{ i } '
103+ ]
104+ # check if output_names is set correctly.
105+ onnx_cfg = get_ir_config (deploy_cfg )
106+ output_names = onnx_cfg ['output_names' ]
107+ if output_names != list (expect_names ):
108+ raise RuntimeError (f'`output_names` should be { expect_names } '
109+ f'but given { output_names } \n '
110+ f'Deploy config:\n { deploy_cfg .pretty_text } ' )
111+ return tuple (preds )
0 commit comments