Skip to content

Commit 3ab17f5

Browse files
authored
Fix reg test for maskrcnn (#2230)
1 parent e19f6fa commit 3ab17f5

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from mmdeploy.codebase.mmdet import get_post_processing_params
99
from mmdeploy.core import FUNCTION_REWRITER
1010
from mmdeploy.mmcv.ops.nms import multiclass_nms
11-
from mmdeploy.utils import Backend
12-
from mmdeploy.utils.config_utils import get_backend_config
11+
from mmdeploy.utils import Backend, get_backend
1312

1413

1514
@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
@@ -166,9 +165,9 @@ def yolox_pose_head__predict_by_feat(
166165

167166
pred_kpts = torch.cat([flatten_decoded_kpts, vis_preds], dim=3)
168167

169-
backend_config = get_backend_config(deploy_cfg)
170-
if backend_config.type == Backend.TENSORRT.value:
171-
# pad
168+
backend = get_backend(deploy_cfg)
169+
if backend == Backend.TENSORRT:
170+
# pad for batched_nms because its output index is filled with -1
172171
bboxes = torch.cat(
173172
[bboxes,
174173
bboxes.new_zeros((bboxes.shape[0], 1, bboxes.shape[2]))],
@@ -188,7 +187,7 @@ def yolox_pose_head__predict_by_feat(
188187
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
189188
score_threshold = cfg.get('score_thr', post_params.score_threshold)
190189
pre_top_k = post_params.get('pre_top_k', -1)
191-
keep_top_k = post_params.get('keep_top_k', -1)
190+
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
192191
# do nms
193192
_, _, nms_indices = multiclass_nms(
194193
bboxes,

tools/regression_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,9 @@ def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
304304
task_name = metafile_metric['Task']
305305
dataset = metafile_metric['Dataset']
306306

307-
# check if metafile use the same metric on several datasets
308-
if len(metafile_metric_info) > 1:
307+
# check if metafile use the same metric on several datasets for mmagic
308+
task_info = set([_['Task'] for _ in metafile_metric_info])
309+
if len(metafile_metric_info) > 1 and len(task_info) == 1:
309310
for k, v in metafile_metric['Metrics'].items():
310311
pytorch_metric[f'{dataset} {k}'] = v
311312
else:

0 commit comments

Comments
 (0)