Skip to content

Commit 929132d

Browse files
authored
Fix binary classification metric task (#3928)
* Fix binary classification * Add unit-tests
1 parent 112b2b2 commit 929132d

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

src/otx/core/metrics/accuracy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,10 @@ def compute(self) -> torch.Tensor:
346346

347347

348348
def _multi_class_cls_metric_callable(label_info: LabelInfo) -> MetricCollection:
349+
num_classes = label_info.num_classes
350+
task = "binary" if num_classes == 1 else "multiclass"
349351
return MetricCollection(
350-
{"accuracy": TorchmetricAcc(task="multiclass", num_classes=label_info.num_classes)},
352+
{"accuracy": TorchmetricAcc(task=task, num_classes=num_classes)},
351353
)
352354

353355

tests/unit/core/metrics/test_accuracy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
HlabelAccuracy,
1010
MixedHLabelAccuracy,
1111
MulticlassAccuracywithLabelGroup,
12+
MultiClassClsMetricCallable,
1213
MultilabelAccuracywithLabelGroup,
1314
)
1415
from otx.core.types.label import HLabelInfo, LabelInfo
16+
from torchmetrics.classification.accuracy import BinaryAccuracy, MulticlassAccuracy
1517

1618

1719
class TestAccuracy:
@@ -45,6 +47,16 @@ def test_multiclass_accuracy(self, fxt_multiclass_labelinfo: LabelInfo) -> None:
4547
acc = result["accuracy"]
4648
assert round(acc.item(), 3) == 0.792
4749

50+
def test_default_multi_class_cls_metric_callable(self, fxt_multiclass_labelinfo: LabelInfo) -> None:
51+
assert fxt_multiclass_labelinfo.num_classes > 1
52+
metric = MultiClassClsMetricCallable(fxt_multiclass_labelinfo)
53+
assert isinstance(metric.accuracy, MulticlassAccuracy)
54+
55+
one_class_label_info = LabelInfo(label_names=["class1"], label_groups=[["class1"]])
56+
assert one_class_label_info.num_classes == 1
57+
binary_metric = MultiClassClsMetricCallable(one_class_label_info)
58+
assert isinstance(binary_metric.accuracy, BinaryAccuracy)
59+
4860
def test_multilabel_accuracy(self, fxt_multilabel_labelinfo: LabelInfo) -> None:
4961
"""Check whether accuracy is same with OTX1.x version."""
5062
preds = [

0 commit comments

Comments
 (0)