from typing import List, Optional, Tuple

import torch
-from mmengine.config import ConfigDict
from torch import Tensor

from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.utils import Backend, get_backend


-@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
-                                     'YOLOXPoseHead.predict')
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmpose.models.heads.hybrid_heads.'
+    'yoloxpose_head.YOLOXPoseHead.forward')
def predict(self,
            x: Tuple[Tensor],
-            batch_data_samples=None,
-            rescale: bool = True):
+            batch_data_samples: List = [],
+            test_cfg: Optional[dict] = None):
2020 """Get predictions and transform to bbox and keypoints results.
2121 Args:
2222 x (Tuple[Tensor]): The input tensor from upstream network.
2323 batch_data_samples: Batch image meta info. Defaults to None.
24- rescale: If True, return boxes in original image space.
25- Defaults to False.
24+ test_cfg: The runtime config for testing process.
2625
2726 Returns:
2827 Tuple[Tensor]: Predict bbox and keypoint results.
@@ -33,73 +32,17 @@ def predict(self,
            Tensor, has shape (batch_size, num_instances, num_keypoints, 3),
            the last dimension 3 arranged as (x, y, score).
    """
-    outs = self(x)
-    predictions = self.predict_by_feat(
-        *outs, batch_img_metas=batch_data_samples, rescale=rescale)
-    return predictions
-
-
-@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
-                                     'YOLOXPoseHead.predict_by_feat')
-def yolox_pose_head__predict_by_feat(
-        self,
-        cls_scores: List[Tensor],
-        bbox_preds: List[Tensor],
-        objectnesses: Optional[List[Tensor]] = None,
-        kpt_preds: Optional[List[Tensor]] = None,
-        vis_preds: Optional[List[Tensor]] = None,
-        batch_img_metas: Optional[List[dict]] = None,
-        cfg: Optional[ConfigDict] = None,
-        rescale: bool = True,
-        with_nms: bool = True) -> Tuple[Tensor]:
-    """Transform a batch of output features extracted by the head into bbox and
-    keypoint results.
-
-    In addition to the base class method, keypoint predictions are also
-    calculated in this method.

-    Args:
-        cls_scores (List[Tensor]): Classification scores for all
-            scale levels, each is a 4D-tensor, has shape
-            (batch_size, num_priors * num_classes, H, W).
-        bbox_preds (List[Tensor]): Box energies / deltas for all
-            scale levels, each is a 4D-tensor, has shape
-            (batch_size, num_priors * 4, H, W).
-        objectnesses (Optional[List[Tensor]]): Score factor for
-            all scale level, each is a 4D-tensor, has shape
-            (batch_size, 1, H, W).
-        kpt_preds (Optional[List[Tensor]]): Keypoints for all
-            scale levels, each is a 4D-tensor, has shape
-            (batch_size, num_keypoints * 2, H, W)
-        vis_preds (Optional[List[Tensor]]): Keypoints scores for
-            all scale levels, each is a 4D-tensor, has shape
-            (batch_size, num_keypoints, H, W)
-        batch_img_metas (Optional[List[dict]]): Batch image meta
-            info. Defaults to None.
-        cfg (Optional[ConfigDict]): Test / postprocessing
-            configuration, if None, test_cfg would be used.
-            Defaults to None.
-        rescale (bool): If True, return boxes in original image space.
-            Defaults to False.
-        with_nms (bool): If True, do nms before return boxes.
-            Defaults to True.
-    Returns:
-        Tuple[Tensor]: Predict bbox and keypoint results.
-        - dets (Tensor): Predict bboxes and scores, which is a 3D Tensor,
-            has shape (batch_size, num_instances, 5), the last dimension 5
-            arrange as (x1, y1, x2, y2, score).
-        - pred_kpts (Tensor): Predict keypoints and scores, which is a 4D
-            Tensor, has shape (batch_size, num_instances, num_keypoints, 5),
-            the last dimension 3 arrange as (x, y, score).
-    """
+    cls_scores, objectnesses, bbox_preds, kpt_offsets, \
+        kpt_vis = self.head_module(x)[:5]
+
    ctx = FUNCTION_REWRITER.get_context()
    deploy_cfg = ctx.cfg
    dtype = cls_scores[0].dtype
    device = cls_scores[0].device
-    bbox_decoder = self.bbox_coder.decode

    assert len(cls_scores) == len(bbox_preds)
-    cfg = self.test_cfg if cfg is None else cfg
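+    # use the test_cfg passed in at call time when given, else the head's own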
+    cfg = self.test_cfg if test_cfg is None else test_cfg

    num_imgs = cls_scores[0].shape[0]
    featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
@@ -110,60 +53,27 @@ def yolox_pose_head__predict_by_feat(
    flatten_priors = torch.cat(self.mlvl_priors)

    mlvl_strides = [
-        flatten_priors.new_full(
-            (featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
-            stride)
+        flatten_priors.new_full((featmap_size.numel(), ), stride)
        for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
    ]
    flatten_stride = torch.cat(mlvl_strides)

    # flatten cls_scores, bbox_preds and objectness
-    flatten_cls_scores = [
-        cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
-        for cls_score in cls_scores
-    ]
-    cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
-
-    flatten_bbox_preds = [
-        bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
-        for bbox_pred in bbox_preds
-    ]
-    flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
-
-    if objectnesses is not None:
-        flatten_objectness = [
-            objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
-            for objectness in objectnesses
-        ]
-        flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
-        cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))
-
-    scores = cls_scores
-    bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
-                          flatten_stride)
-
-    # deal with key-points
-    priors = torch.cat(self.mlvl_priors)
-    strides = [
-        priors.new_full((featmap_size.numel() * self.num_base_priors, ),
-                        stride)
-        for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
-    ]
-    strides = torch.cat(strides)
-    kpt_preds = torch.cat([
-        kpt_pred.permute(0, 2, 3, 1).reshape(
-            num_imgs, -1, self.num_keypoints * 2) for kpt_pred in kpt_preds
-    ],
-                          dim=1)
-    flatten_decoded_kpts = self.decode_pose(priors, kpt_preds, strides)
-
-    vis_preds = torch.cat([
-        vis_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_keypoints,
-                                             1) for vis_pred in vis_preds
-    ],
-                          dim=1).sigmoid()
-
-    pred_kpts = torch.cat([flatten_decoded_kpts, vis_preds], dim=3)
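+    # flatten the per-level dense maps, then decode boxes and keypoint offsets
+    # against the flattened priors and strides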
+    flatten_cls_scores = self._flatten_predictions(cls_scores).sigmoid()
+    flatten_bbox_preds = self._flatten_predictions(bbox_preds)
+    flatten_objectness = self._flatten_predictions(objectnesses).sigmoid()
+    flatten_kpt_offsets = self._flatten_predictions(kpt_offsets)
+    flatten_kpt_vis = self._flatten_predictions(kpt_vis).sigmoid()
+    bboxes = self.decode_bbox(flatten_bbox_preds, flatten_priors,
+                              flatten_stride)
+    flatten_decoded_kpts = self.decode_kpt_reg(flatten_kpt_offsets,
+                                               flatten_priors, flatten_stride)
+
+    scores = flatten_cls_scores * flatten_objectness
+
+    pred_kpts = torch.cat([flatten_decoded_kpts,
+                           flatten_kpt_vis.unsqueeze(3)],
+                          dim=3)
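+    # shapes entering NMS: bboxes (B, N, 4), scores (B, N, num_classes),
+    # pred_kpts (B, N, num_keypoints, 3) with (x, y, visibility score)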

    backend = get_backend(deploy_cfg)
    if backend == Backend.TENSORRT:
@@ -184,10 +94,11 @@ def yolox_pose_head__predict_by_feat(
    # nms
    post_params = get_post_processing_params(deploy_cfg)
    max_output_boxes_per_class = post_params.max_output_boxes_per_class
-    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
+    iou_threshold = cfg.get('nms_thr', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.get('pre_top_k', -1)
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
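+    # NMS thresholds come from the model's test_cfg when present, otherwise
+    # from the deploy config's post-processing parameters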
+
    # do nms
    _, _, nms_indices = multiclass_nms(
        bboxes,