Skip to content

Commit 26a681d

Browse files
authored
Fix KP Detection performance (#4270)
* fix kp det performance, change export, resize and augmentations
* move pad out of the resize
* remove second model, add decode_scores
* remove second config for RTMPose
* rename kp template
* fix visual prompting
* fix torchvision unit test
* add unit test for rescale_kp
* delete redundant test
1 parent fd3fa0c commit 26a681d

File tree

15 files changed

+214
-357
lines changed

15 files changed

+214
-357
lines changed

src/otx/algo/keypoint_detection/rtmpose.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,12 @@ def _exporter(self) -> OTXModelExporter:
4343
if self.explain_mode:
4444
msg = "Export with explain is not supported for RTMPose model."
4545
logger.warning(msg)
46-
4746
return OTXNativeModelExporter(
4847
task_level_export_parameters=self._export_parameters,
4948
input_size=(1, 3, *self.input_size),
5049
mean=self.mean,
5150
std=self.std,
52-
resize_mode="standard",
51+
resize_mode="fit_to_window",
5352
pad_value=0,
5453
swap_rgb=False,
5554
via_onnx=True,
@@ -115,6 +114,7 @@ def _build_model(self, num_classes: int) -> RTMPose:
115114
"sigma": sigma,
116115
"normalize": False,
117116
"use_dark": False,
117+
"decode_scores": True,
118118
},
119119
gau_cfg={
120120
"num_token": num_classes,

src/otx/algo/keypoint_detection/utils/simcc_label.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
label_smooth_weight: float = 0.0,
6868
normalize: bool = True,
6969
use_dark: bool = False,
70-
decode_visibility: bool = False,
70+
decode_scores: bool = True,
7171
decode_beta: float = 150.0,
7272
) -> None:
7373
self.input_size = input_size
@@ -76,7 +76,7 @@ def __init__(
7676
self.label_smooth_weight = label_smooth_weight
7777
self.normalize = normalize
7878
self.use_dark = use_dark
79-
self.decode_visibility = decode_visibility
79+
self.decode_scores = decode_scores
8080
self.decode_beta = decode_beta
8181

8282
if isinstance(sigma, (float, int)):
@@ -170,13 +170,13 @@ def decode(self, simcc_x: np.ndarray, simcc_y: np.ndarray) -> tuple[np.ndarray,
170170

171171
keypoints /= self.simcc_split_ratio
172172

173-
if self.decode_visibility:
173+
if self.decode_scores:
174174
_, visibility = get_simcc_maximum(
175175
simcc_x * self.decode_beta * self.sigma[0],
176176
simcc_y * self.decode_beta * self.sigma[1],
177177
apply_softmax=True,
178178
)
179-
return keypoints, (scores, visibility)
179+
return keypoints, visibility
180180
return keypoints, scores
181181

182182
def _map_coordinates(

src/otx/core/data/dataset/keypoint_detection.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,12 @@ def _get_item_impl(self, index: int) -> KeypointDetDataEntity | None:
108108
if len(keypoint_anns) > 0
109109
else np.zeros((0, len(self.label_info.label_names) * 2), dtype=np.float32)
110110
).reshape(-1, 2)
111-
keypoints_visible = np.minimum(1, keypoints)[..., 0]
111+
112+
keypoints_visible = (
113+
(np.array([ann.visibility for ann in keypoint_anns]) > 1).reshape(-1).astype(np.int8)
114+
if len(keypoint_anns) > 0 and hasattr(keypoint_anns[0], "visibility")
115+
else np.minimum(1, keypoints)[..., 0]
116+
)
112117

113118
bbox_center = np.array(img_shape) / 2.0
114119
bbox_scale = np.array(img_shape)

0 commit comments

Comments
 (0)