|
11 | 11 | import torch |
12 | 12 |
|
13 | 13 | from otx.algo.utils.mmengine_utils import load_checkpoint |
14 | | -from otx.core.data.entity.base import OTXBatchLossEntity |
| 14 | +from otx.core.data.entity.base import ImageInfo, OTXBatchLossEntity |
15 | 15 | from otx.core.data.entity.keypoint_detection import KeypointDetBatchDataEntity, KeypointDetBatchPredEntity |
16 | 16 | from otx.core.metrics import MetricCallable, MetricInput |
17 | 17 | from otx.core.metrics.pck import PCKMeasureCallable |
@@ -150,6 +150,40 @@ def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | tuple[torch |
150 | 150 | """Model forward function used for the model tracing during model exportation.""" |
151 | 151 | return self.model.forward(inputs=image, mode="tensor") |
152 | 152 |
|
| 153 | + def get_dummy_input(self, batch_size: int = 1) -> KeypointDetBatchDataEntity: |
| 154 | + """Generates a dummy input, suitable for launching forward() on it. |
| 155 | +
|
| 156 | + Args: |
| 157 | + batch_size (int, optional): number of elements in a dummy input sequence. Defaults to 1. |
| 158 | +
|
| 159 | + Returns: |
| 160 | + KeypointDetBatchDataEntity: An entity containing randomly generated inference data. |
| 161 | + """ |
| 162 | + if self.input_size is None: |
| 163 | + msg = f"Input size attribute is not set for {self.__class__}" |
| 164 | + raise ValueError(msg) |
| 165 | + |
| 166 | + images = torch.rand(batch_size, 3, *self.input_size) |
| 167 | + infos = [] |
| 168 | + for i, img in enumerate(images): |
| 169 | + infos.append( |
| 170 | + ImageInfo( |
| 171 | + img_idx=i, |
| 172 | + img_shape=img.shape, |
| 173 | + ori_shape=img.shape, |
| 174 | + ), |
| 175 | + ) |
| 176 | + return KeypointDetBatchDataEntity( |
| 177 | + batch_size, |
| 178 | + images, |
| 179 | + infos, |
| 180 | + bboxes=[], |
| 181 | + labels=[], |
| 182 | + bbox_info=[], |
| 183 | + keypoints=[], |
| 184 | + keypoints_visible=[], |
| 185 | + ) |
| 186 | + |
153 | 187 | @property |
154 | 188 | def _export_parameters(self) -> TaskLevelExportParameters: |
155 | 189 | """Defines parameters required to export a particular model implementation.""" |
|
0 commit comments