2121import numpy as np
2222
2323from .image_model import ImageModel
24- from .types import ListValue
24+ from .types import BooleanValue , ListValue
2525from .utils import DetectedKeypoints , Detection
2626
2727
@@ -59,7 +59,7 @@ def postprocess(
5959 DetectedKeypoints: detected keypoints
6060 """
6161 encoded_kps = list (outputs .values ())
62- batch_keypoints , batch_scores = _decode_simcc (* encoded_kps )
62+ batch_keypoints , batch_scores = _decode_simcc (* encoded_kps , apply_softmax = self . apply_softmax )
6363 orig_h , orig_w = meta ["original_shape" ][:2 ]
6464 kp_scale_h = orig_h / self .h
6565 kp_scale_w = orig_w / self .w
@@ -74,6 +74,9 @@ def parameters(cls) -> dict:
7474 "labels" : ListValue (
7575 description = "List of class labels" , value_type = str , default_value = []
7676 ),
77+ "apply_softmax" : BooleanValue (
78+ default_value = True , description = "Whether to apply softmax on the heatmap."
79+ ),
7780 }
7881 )
7982 return parameters
@@ -127,22 +130,25 @@ def predict_crops(self, crops: list[np.ndarray]) -> list[DetectedKeypoints]:
127130
128131
129132def _decode_simcc (
130- simcc_x : np .ndarray , simcc_y : np .ndarray , simcc_split_ratio : float = 2.0
133+ simcc_x : np .ndarray , simcc_y : np .ndarray , simcc_split_ratio : float = 2.0 ,
134+ apply_softmax : bool = False ,
131135) -> tuple [np .ndarray , np .ndarray ]:
132136 """Decodes keypoint coordinates from SimCC representations. The decoded coordinates are in the input image space.
133137
134138 Args:
135139 simcc_x (np.ndarray): SimCC label for x-axis
136140 simcc_y (np.ndarray): SimCC label for y-axis
137141 simcc_split_ratio (float): The ratio of the label size to the input size.
142+ apply_softmax (bool): whether to apply softmax on the heatmap.
143+ Defaults to False.
138144
139145 Returns:
140146 tuple:
141147 - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
142148 - scores (np.ndarray): The keypoint scores in shape (N, K).
143149 It usually represents the confidence of the keypoint prediction
144150 """
145- keypoints , scores = _get_simcc_maximum (simcc_x , simcc_y )
151+ keypoints , scores = _get_simcc_maximum (simcc_x , simcc_y , apply_softmax )
146152
147153 # Unsqueeze the instance dimension for single-instance results
148154 if keypoints .ndim == 2 :
@@ -157,6 +163,7 @@ def _decode_simcc(
157163def _get_simcc_maximum (
158164 simcc_x : np .ndarray ,
159165 simcc_y : np .ndarray ,
166+ apply_softmax : bool = False ,
160167) -> tuple [np .ndarray , np .ndarray ]:
161168 """Get maximum response location and value from simcc representations.
162169
@@ -169,6 +176,8 @@ def _get_simcc_maximum(
169176 Args:
170177 simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
171178 simcc_y (np.ndarray): y-axis SimCC in shape (K, Hy) or (N, K, Hy)
179+ apply_softmax (bool): whether to apply softmax on the heatmap.
180+ Defaults to False.
172181
173182 Returns:
174183 tuple:
@@ -194,6 +203,13 @@ def _get_simcc_maximum(
194203 else :
195204 batch_size = None
196205
206+ if apply_softmax :
207+ simcc_x = simcc_x - np .max (simcc_x , axis = 1 , keepdims = True )
208+ simcc_y = simcc_y - np .max (simcc_y , axis = 1 , keepdims = True )
209+ ex , ey = np .exp (simcc_x ), np .exp (simcc_y )
210+ simcc_x = ex / np .sum (ex , axis = 1 , keepdims = True )
211+ simcc_y = ey / np .sum (ey , axis = 1 , keepdims = True )
212+
197213 x_locs = np .argmax (simcc_x , axis = 1 )
198214 y_locs = np .argmax (simcc_y , axis = 1 )
199215 locs = np .stack ((x_locs , y_locs ), axis = - 1 ).astype (np .float32 )
0 commit comments