2121import numpy as np
2222
2323from .image_model import ImageModel
24- from .types import ListValue
24+ from .types import BooleanValue , ListValue
2525from .utils import DetectedKeypoints , Detection
2626
2727
@@ -59,7 +59,9 @@ def postprocess(
5959 DetectedKeypoints: detected keypoints
6060 """
6161 encoded_kps = list (outputs .values ())
62- batch_keypoints , batch_scores = _decode_simcc (* encoded_kps )
62+ batch_keypoints , batch_scores = _decode_simcc (
63+ * encoded_kps , apply_softmax = self .apply_softmax
64+ )
6365 orig_h , orig_w = meta ["original_shape" ][:2 ]
6466 kp_scale_h = orig_h / self .h
6567 kp_scale_w = orig_w / self .w
@@ -74,6 +76,10 @@ def parameters(cls) -> dict:
7476 "labels" : ListValue (
7577 description = "List of class labels" , value_type = str , default_value = []
7678 ),
79+ "apply_softmax" : BooleanValue (
80+ default_value = True ,
81+ description = "Whether to apply softmax on the heatmap." ,
82+ ),
7783 }
7884 )
7985 return parameters
@@ -127,22 +133,27 @@ def predict_crops(self, crops: list[np.ndarray]) -> list[DetectedKeypoints]:
127133
128134
129135def _decode_simcc (
130- simcc_x : np .ndarray , simcc_y : np .ndarray , simcc_split_ratio : float = 2.0
136+ simcc_x : np .ndarray ,
137+ simcc_y : np .ndarray ,
138+ simcc_split_ratio : float = 2.0 ,
139+ apply_softmax : bool = False ,
131140) -> tuple [np .ndarray , np .ndarray ]:
132141 """Decodes keypoint coordinates from SimCC representations. The decoded coordinates are in the input image space.
133142
134143 Args:
135144 simcc_x (np.ndarray): SimCC label for x-axis
136145 simcc_y (np.ndarray): SimCC label for y-axis
137146 simcc_split_ratio (float): The ratio of the label size to the input size.
147+ apply_softmax (bool): whether to apply softmax on the heatmap.
148+ Defaults to False.
138149
139150 Returns:
140151 tuple:
141152 - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D)
142153 - scores (np.ndarray): The keypoint scores in shape (N, K).
143154 It usually represents the confidence of the keypoint prediction
144155 """
145- keypoints , scores = _get_simcc_maximum (simcc_x , simcc_y )
156+ keypoints , scores = _get_simcc_maximum (simcc_x , simcc_y , apply_softmax )
146157
147158 # Unsqueeze the instance dimension for single-instance results
148159 if keypoints .ndim == 2 :
@@ -157,6 +168,7 @@ def _decode_simcc(
157168def _get_simcc_maximum (
158169 simcc_x : np .ndarray ,
159170 simcc_y : np .ndarray ,
171+ apply_softmax : bool = False ,
160172) -> tuple [np .ndarray , np .ndarray ]:
161173 """Get maximum response location and value from simcc representations.
162174
@@ -169,6 +181,8 @@ def _get_simcc_maximum(
169181 Args:
170182 simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
171183 simcc_y (np.ndarray): y-axis SimCC in shape (K, Hy) or (N, K, Hy)
184+ apply_softmax (bool): whether to apply softmax on the heatmap.
185+ Defaults to False.
172186
173187 Returns:
174188 tuple:
@@ -194,6 +208,13 @@ def _get_simcc_maximum(
194208 else :
195209 batch_size = None
196210
211+ if apply_softmax :
212+ simcc_x = simcc_x - np .max (simcc_x , axis = 1 , keepdims = True )
213+ simcc_y = simcc_y - np .max (simcc_y , axis = 1 , keepdims = True )
214+ ex , ey = np .exp (simcc_x ), np .exp (simcc_y )
215+ simcc_x = ex / np .sum (ex , axis = 1 , keepdims = True )
216+ simcc_y = ey / np .sum (ey , axis = 1 , keepdims = True )
217+
197218 x_locs = np .argmax (simcc_x , axis = 1 )
198219 y_locs = np .argmax (simcc_y , axis = 1 )
199220 locs = np .stack ((x_locs , y_locs ), axis = - 1 ).astype (np .float32 )
0 commit comments