Skip to content

Commit df0bf8f

Browse files
authored
Fix yolox-pose ut (#2231)
* update yolox-pose ut * fix lint * fix
1 parent 3ab17f5 commit df0bf8f

File tree

1 file changed

+90
-9
lines changed

1 file changed

+90
-9
lines changed

tests/test_codebase/test_mmpose/test_mmpose_models.py

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import mmengine
33
import pytest
44
import torch
5+
from mmengine.config import ConfigDict
6+
from mmengine.structures import InstanceData
57

68
from mmdeploy.codebase import import_codebase
79
from mmdeploy.utils import Backend, Codebase
@@ -194,15 +196,18 @@ def test_scale_forward(backend_type: Backend):
194196
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
195197
def test_yolox_pose_head(backend_type: Backend):
196198
try:
197-
from models import yolox_pose_head # noqa: F401,F403
199+
from mmyolo.utils.setup_env import register_all_modules
200+
from models.yolox_pose_head import YOLOXPoseHead # noqa: F401,F403
201+
register_all_modules(True)
198202
except ImportError:
199203
pytest.skip(
200204
'mmpose/projects/yolox-pose is not installed.',
201205
allow_module_level=True)
202206
deploy_cfg = mmengine.Config.fromfile(
203207
'configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py')
204208
check_backend(backend_type, True)
205-
model = yolox_pose_head.YOLOXPoseHead(
209+
210+
head = YOLOXPoseHead(
206211
head_module=dict(
207212
type='YOLOXPoseHeadModule',
208213
num_classes=1,
@@ -215,19 +220,95 @@ def test_yolox_pose_head(backend_type: Backend):
215220
use_depthwise=False,
216221
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
217222
act_cfg=dict(type='SiLU', inplace=True),
218-
))
223+
),
224+
loss_cls=dict(
225+
type='mmdet.CrossEntropyLoss',
226+
use_sigmoid=True,
227+
reduction='sum',
228+
loss_weight=1.0),
229+
loss_bbox=dict(
230+
type='mmdet.IoULoss',
231+
mode='square',
232+
eps=1e-16,
233+
reduction='sum',
234+
loss_weight=5.0),
235+
loss_obj=dict(
236+
type='mmdet.CrossEntropyLoss',
237+
use_sigmoid=True,
238+
reduction='sum',
239+
loss_weight=1.0),
240+
loss_pose=dict(
241+
type='OksLoss',
242+
metainfo='configs/_base_/datasets/coco.py',
243+
loss_weight=30.0),
244+
loss_bbox_aux=dict(
245+
type='mmdet.L1Loss', reduction='sum', loss_weight=1.0),
246+
train_cfg=ConfigDict(
247+
assigner=dict(
248+
type='PoseSimOTAAssigner',
249+
center_radius=2.5,
250+
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
251+
oks_calculator=dict(
252+
type='OksLoss',
253+
metainfo='configs/_base_/datasets/coco.py'))),
254+
test_cfg=ConfigDict(
255+
yolox_style=True,
256+
multi_label=False,
257+
score_thr=0.001,
258+
max_per_img=300,
259+
nms=dict(type='nms', iou_threshold=0.65)))
260+
261+
class TestYOLOXPoseHeadModel(torch.nn.Module):
262+
263+
def __init__(self, yolox_pose_head):
264+
super(TestYOLOXPoseHeadModel, self).__init__()
265+
self.yolox_pose_head = yolox_pose_head
266+
267+
def forward(self, x1, x2, x3):
268+
inputs = [x1, x2, x3]
269+
data_sample = InstanceData()
270+
data_sample.set_metainfo(
271+
dict(ori_shape=(640, 640), scale_factor=(1.0, 1.0)))
272+
return self.yolox_pose_head.predict(
273+
inputs, batch_data_samples=[data_sample])
274+
275+
model = TestYOLOXPoseHeadModel(head)
219276
model.cpu().eval()
277+
220278
model_inputs = [
221-
torch.randn(2, 128, 80, 80),
222-
torch.randn(2, 128, 40, 40),
223-
torch.randn(2, 128, 20, 20)
279+
torch.randn(1, 128, 8, 8),
280+
torch.randn(1, 128, 4, 4),
281+
torch.randn(1, 128, 2, 2)
282+
]
283+
284+
with torch.no_grad():
285+
pytorch_output = model(*model_inputs)[0]
286+
pred_bboxes = torch.from_numpy(pytorch_output.bboxes).unsqueeze(0)
287+
pred_bboxes_scores = torch.from_numpy(pytorch_output.scores).reshape(
288+
1, -1, 1)
289+
pred_kpts = torch.from_numpy(pytorch_output.keypoints).unsqueeze(0)
290+
pred_kpts_scores = torch.from_numpy(
291+
pytorch_output.keypoint_scores).unsqueeze(0).unsqueeze(-1)
292+
293+
pytorch_output = [
294+
torch.cat([pred_bboxes, pred_bboxes_scores], dim=-1),
295+
torch.cat([pred_kpts, pred_kpts_scores], dim=-1)
224296
]
225-
pytorch_output = model(model_inputs)
297+
226298
wrapped_model = WrapModel(model, 'forward')
227-
rewrite_inputs = {'inputs': model_inputs}
299+
rewrite_inputs = {
300+
'x1': model_inputs[0],
301+
'x2': model_inputs[1],
302+
'x3': model_inputs[2]
303+
}
304+
deploy_cfg.onnx_config.input_names = ['x1', 'x2', 'x3']
305+
228306
rewrite_outputs, _ = get_rewrite_outputs(
229307
wrapped_model=wrapped_model,
230308
model_inputs=rewrite_inputs,
231-
run_with_backend=False,
309+
run_with_backend=True,
232310
deploy_cfg=deploy_cfg)
311+
312+
# keep bbox coord >= 0
313+
rewrite_outputs[0] = rewrite_outputs[0].clamp(min=0)
233314
torch_assert_close(rewrite_outputs, pytorch_output)

0 commit comments

Comments
 (0)