Skip to content

Commit 2bcf1b2

Browse files
authored
Fix classification rt_info (#3922)
* Restore output_raw_scores for classificaiton * Add uts * Fix linter
1 parent 51d1adf commit 2bcf1b2

File tree

4 files changed

+13
-0
lines changed

4 files changed

+13
-0
lines changed

src/otx/core/model/classification.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def _export_parameters(self) -> TaskLevelExportParameters:
154154
task_type="classification",
155155
multilabel=False,
156156
hierarchical=False,
157+
output_raw_scores=True,
157158
)
158159

159160
@property
@@ -279,6 +280,7 @@ def _export_parameters(self) -> TaskLevelExportParameters:
279280
multilabel=True,
280281
hierarchical=False,
281282
confidence_threshold=0.5,
283+
output_raw_scores=True,
282284
)
283285

284286
@property
@@ -401,6 +403,7 @@ def _export_parameters(self) -> TaskLevelExportParameters:
401403
multilabel=False,
402404
hierarchical=True,
403405
confidence_threshold=0.5,
406+
output_raw_scores=True,
404407
)
405408

406409
@property

src/otx/core/types/export.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class TaskLevelExportParameters:
3434
Only specified for the classification task.
3535
hierarchical (bool | None): Whether it is hierarchical or not.
3636
Only specified for the classification task.
37+
output_raw_scores (bool | None): Whether to output raw scores.
38+
Only specified for the classification task.
3739
confidence_threshold (float | None): Confidence threshold for model prediction probability.
3840
It is used only for classification tasks, detection and instance segmentation tasks.
3941
iou_threshold (float | None): The Intersection over Union (IoU) threshold
@@ -60,6 +62,7 @@ class TaskLevelExportParameters:
6062
# (Optional) Classification tasks
6163
multilabel: bool | None = None
6264
hierarchical: bool | None = None
65+
output_raw_scores: bool | None = None
6366

6467
# (Optional) Classification tasks, detection and instance segmentation task
6568
confidence_threshold: float | None = None
@@ -133,6 +136,9 @@ def to_metadata(self) -> dict[tuple[str, str], str]:
133136
if self.hierarchical is not None:
134137
metadata[("model_info", "hierarchical")] = str(self.hierarchical)
135138

139+
if self.output_raw_scores is not None:
140+
metadata[("model_info", "output_raw_scores")] = str(self.output_raw_scores)
141+
136142
if self.confidence_threshold is not None:
137143
metadata[("model_info", "confidence_threshold")] = str(self.confidence_threshold)
138144

tests/unit/core/model/test_classification.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def test_export_parameters(
4848
assert model._export_parameters.task_type.lower() == "classification"
4949
assert not model._export_parameters.multilabel
5050
assert not model._export_parameters.hierarchical
51+
assert model._export_parameters.output_raw_scores
5152

5253
model = OTXMultilabelClsModel(
5354
label_info=1,

tests/unit/core/types/test_export.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def test_wrap(fxt_label_info, task_type):
1717

1818
multilabel = False
1919
hierarchical = False
20+
output_raw_scores = True
2021
confidence_threshold = 0.0
2122
iou_threshold = 0.0
2223
return_soft_prediction = False
@@ -27,6 +28,7 @@ def test_wrap(fxt_label_info, task_type):
2728
params = params.wrap(
2829
multilabel=multilabel,
2930
hierarchical=hierarchical,
31+
output_raw_scores=output_raw_scores,
3032
confidence_threshold=confidence_threshold,
3133
iou_threshold=iou_threshold,
3234
return_soft_prediction=return_soft_prediction,
@@ -44,6 +46,7 @@ def test_wrap(fxt_label_info, task_type):
4446
assert metadata[("model_info", "return_soft_prediction")] == str(return_soft_prediction)
4547
assert metadata[("model_info", "soft_threshold")] == str(soft_threshold)
4648
assert metadata[("model_info", "blur_strength")] == str(blur_strength)
49+
assert metadata[("model_info", "output_raw_scores")] == str(output_raw_scores)
4750

4851
# Tile config
4952
assert ("model_info", "tile_size") in metadata

0 commit comments

Comments
 (0)