Skip to content

Commit 08de045

Browse files
authored
Fix multilabel_accuracy of MixedHLabelAccuracy (#4042)
* Fix metric for multi-label * Fix1 * Add CHANGELOG
1 parent 93adf0d commit 08de045

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ All notable changes to this project will be documented in this file.
8686
(<https://github.com/openvinotoolkit/training_extensions/pull/4014>)
8787
- Fix out_features in HierarchicalCBAMClsHead
8888
(<https://github.com/openvinotoolkit/training_extensions/pull/4016>)
89+
- Fix multilabel_accuracy of MixedHLabelAccuracy
90+
(<https://github.com/openvinotoolkit/training_extensions/pull/4042>)
8991

9092
## \[v2.1.0\]
9193

src/otx/core/metrics/accuracy.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,17 @@ def __init__(
288288
]
289289

290290
# Multilabel classification accuracy metrics
291-
if self.num_multilabel_classes > 0:
291+
# https://github.com/Lightning-AI/torchmetrics/blob/6377aa5b6fe2863761839e6b8b5a857ef1b8acfa/src/torchmetrics/functional/classification/stat_scores.py#L583-L584
292+
# MultilabelAccuracy is available when num_multilabel_classes is greater than 2.
293+
self.multilabel_accuracy = None
294+
if self.num_multilabel_classes > 1:
292295
self.multilabel_accuracy = TorchmetricMultilabelAcc(
293296
num_labels=self.num_multilabel_classes,
294297
threshold=0.5,
295298
average="macro",
296299
)
300+
elif self.num_multilabel_classes == 1:
301+
self.multilabel_accuracy = TorchmetricAcc(task="binary", num_classes=self.num_multilabel_classes)
297302

298303
def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> nn.Module:
299304
self.multiclass_head_accuracy = [
@@ -303,7 +308,7 @@ def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> nn.Module:
303308
)
304309
for acc in self.multiclass_head_accuracy
305310
]
306-
if self.num_multilabel_classes > 0:
311+
if self.multilabel_accuracy is not None:
307312
self.multilabel_accuracy = self.multilabel_accuracy._apply(fn, exclude_state) # noqa: SLF001
308313
return self
309314

@@ -322,7 +327,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
322327
target_multiclass[multiclass_mask],
323328
)
324329

325-
if self.num_multilabel_classes > 0:
330+
if self.multilabel_accuracy is not None:
326331
# Split preds into multiclass and multilabel parts
327332
preds_multilabel = preds[:, self.num_multiclass_heads :]
328333
target_multilabel = target[:, self.num_multiclass_heads :]
@@ -337,7 +342,7 @@ def compute(self) -> torch.Tensor:
337342
),
338343
)
339344

340-
if self.num_multilabel_classes > 0:
345+
if self.multilabel_accuracy is not None:
341346
multilabel_acc = self.multilabel_accuracy.compute()
342347

343348
return (multiclass_accs + multilabel_acc) / 2

tests/unit/core/metrics/test_accuracy.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
MultilabelAccuracywithLabelGroup,
1414
)
1515
from otx.core.types.label import HLabelInfo, LabelInfo
16-
from torchmetrics.classification.accuracy import BinaryAccuracy, MulticlassAccuracy
16+
from torchmetrics.classification.accuracy import BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
1717

1818

1919
class TestAccuracy:
@@ -120,3 +120,28 @@ def test_multilabel_only(self) -> None:
120120
head_logits_info={"head1": (0, 5), "head2": (5, 10)},
121121
threshold_multilabel=0.5,
122122
)
123+
124+
def test_multilabel_accuracy(self, hlabel_accuracy) -> None:
125+
# Normal Case: num_multilabel_classes > 1 -> MultilabelAccuracy
126+
assert hlabel_accuracy.num_multilabel_classes == 3
127+
assert isinstance(hlabel_accuracy.multilabel_accuracy, MultilabelAccuracy)
128+
129+
# Edge Case: num_multilabel_classes = 1 -> BinaryAccuracy
130+
acc = MixedHLabelAccuracy(
131+
num_multiclass_heads=2,
132+
num_multilabel_classes=1,
133+
head_logits_info={"head1": (0, 5), "head2": (5, 10)},
134+
threshold_multilabel=0.5,
135+
)
136+
assert acc.num_multilabel_classes == 1
137+
assert isinstance(acc.multilabel_accuracy, BinaryAccuracy)
138+
139+
# None Case: num_multilabel_classes = 0 -> None
140+
acc = MixedHLabelAccuracy(
141+
num_multiclass_heads=2,
142+
num_multilabel_classes=0,
143+
head_logits_info={"head1": (0, 5), "head2": (5, 10)},
144+
threshold_multilabel=0.5,
145+
)
146+
assert acc.num_multilabel_classes == 0
147+
assert acc.multilabel_accuracy is None

0 commit comments

Comments
 (0)