Skip to content

Commit d8e6454

Browse files
authored
Fix soft predictions for Semantic Segmentation (#3934)
fix soft preds
1 parent 98a9cac commit d8e6454

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/otx/core/model/segmentation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
219219

220220
def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
221221
"""Model forward function used for the model tracing during model exportation."""
222-
return self.model(inputs=image, mode="tensor")
222+
raw_outputs = self.model(inputs=image, mode="tensor")
223+
return torch.softmax(raw_outputs, dim=1)
223224

224225
def get_dummy_input(self, batch_size: int = 1) -> SegBatchDataEntity:
225226
"""Returns a dummy input for semantic segmentation model."""

0 commit comments

Comments
 (0)