We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 98a9cac commit d8e6454Copy full SHA for d8e6454
src/otx/core/model/segmentation.py
@@ -219,7 +219,8 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
219
220
def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
221
"""Model forward function used for the model tracing during model exportation."""
222
- return self.model(inputs=image, mode="tensor")
+ raw_outputs = self.model(inputs=image, mode="tensor")
223
+ return torch.softmax(raw_outputs, dim=1)
224
225
def get_dummy_input(self, batch_size: int = 1) -> SegBatchDataEntity:
226
"""Returns a dummy input for semantic segmentation model."""
0 commit comments