Skip to content

Commit 15746ea

Browse files
authored
Add Keypoint Detection legacy template (#4094)
added rtmpose_template
1 parent dc882bf commit 15746ea

File tree

4 files changed

+541
-6
lines changed

4 files changed

+541
-6
lines changed

src/otx/core/model/keypoint_detection.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212

1313
from otx.algo.utils.mmengine_utils import load_checkpoint
14-
from otx.core.data.entity.base import OTXBatchLossEntity
14+
from otx.core.data.entity.base import ImageInfo, OTXBatchLossEntity
1515
from otx.core.data.entity.keypoint_detection import KeypointDetBatchDataEntity, KeypointDetBatchPredEntity
1616
from otx.core.metrics import MetricCallable, MetricInput
1717
from otx.core.metrics.pck import PCKMeasureCallable
@@ -150,6 +150,40 @@ def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | tuple[torch
150150
"""Model forward function used for the model tracing during model exportation."""
151151
return self.model.forward(inputs=image, mode="tensor")
152152

153+
def get_dummy_input(self, batch_size: int = 1) -> KeypointDetBatchDataEntity:
154+
"""Generates a dummy input, suitable for launching forward() on it.
155+
156+
Args:
157+
batch_size (int, optional): number of elements in a dummy input sequence. Defaults to 1.
158+
159+
Returns:
160+
KeypointDetBatchDataEntity: An entity containing randomly generated inference data.
161+
"""
162+
if self.input_size is None:
163+
msg = f"Input size attribute is not set for {self.__class__}"
164+
raise ValueError(msg)
165+
166+
images = torch.rand(batch_size, 3, *self.input_size)
167+
infos = []
168+
for i, img in enumerate(images):
169+
infos.append(
170+
ImageInfo(
171+
img_idx=i,
172+
img_shape=img.shape,
173+
ori_shape=img.shape,
174+
),
175+
)
176+
return KeypointDetBatchDataEntity(
177+
batch_size,
178+
images,
179+
infos,
180+
bboxes=[],
181+
labels=[],
182+
bbox_info=[],
183+
keypoints=[],
184+
keypoints_visible=[],
185+
)
186+
153187
@property
154188
def _export_parameters(self) -> TaskLevelExportParameters:
155189
"""Defines parameters required to export a particular model implementation."""

src/otx/tools/converter.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,10 @@
184184
"model_name": "stfpm",
185185
},
186186
# KEYPOINT_DETECTION
187-
"Custom_Keypoint_Detection_Rtmpose_T": {
187+
"Keypoint_Detection_RTMPose_Tiny": {
188188
"task": OTXTaskType.KEYPOINT_DETECTION,
189189
"model_name": "rtmpose_tiny",
190190
},
191-
"Custom_Keypoint_Detection_Rtmpose_T_Single_Obj": {
192-
"task": OTXTaskType.KEYPOINT_DETECTION,
193-
"model_name": "rtmpose_tiny_single_obj",
194-
},
195191
}
196192

197193

0 commit comments

Comments
 (0)