88from mmdeploy .codebase .mmdet import get_post_processing_params
99from mmdeploy .core import FUNCTION_REWRITER
1010from mmdeploy .mmcv .ops .nms import multiclass_nms
11- from mmdeploy .utils import Backend
12- from mmdeploy .utils .config_utils import get_backend_config
11+ from mmdeploy .utils import Backend , get_backend
1312
1413
1514@FUNCTION_REWRITER .register_rewriter (func_name = 'models.yolox_pose_head.'
@@ -166,9 +165,9 @@ def yolox_pose_head__predict_by_feat(
166165
167166 pred_kpts = torch .cat ([flatten_decoded_kpts , vis_preds ], dim = 3 )
168167
169- backend_config = get_backend_config (deploy_cfg )
170- if backend_config . type == Backend .TENSORRT . value :
171- # pad
168+ backend = get_backend (deploy_cfg )
169+ if backend == Backend .TENSORRT :
170+ # pad for batched_nms because its output index is filled with -1
172171 bboxes = torch .cat (
173172 [bboxes ,
174173 bboxes .new_zeros ((bboxes .shape [0 ], 1 , bboxes .shape [2 ]))],
@@ -188,7 +187,7 @@ def yolox_pose_head__predict_by_feat(
188187 iou_threshold = cfg .nms .get ('iou_threshold' , post_params .iou_threshold )
189188 score_threshold = cfg .get ('score_thr' , post_params .score_threshold )
190189 pre_top_k = post_params .get ('pre_top_k' , - 1 )
191- keep_top_k = post_params .get ('keep_top_k ' , - 1 )
190+ keep_top_k = cfg .get ('max_per_img ' , post_params . keep_top_k )
192191 # do nms
193192 _ , _ , nms_indices = multiclass_nms (
194193 bboxes ,
0 commit comments