@@ -288,12 +288,17 @@ def __init__(
288288 ]
289289
290290 # Multilabel classification accuracy metrics
291- if self .num_multilabel_classes > 0 :
291+ # https://github.com/Lightning-AI/torchmetrics/blob/6377aa5b6fe2863761839e6b8b5a857ef1b8acfa/src/torchmetrics/functional/classification/stat_scores.py#L583-L584
292+ # MultilabelAccuracy is available when num_multilabel_classes is greater than 2.
293+ self .multilabel_accuracy = None
294+ if self .num_multilabel_classes > 1 :
292295 self .multilabel_accuracy = TorchmetricMultilabelAcc (
293296 num_labels = self .num_multilabel_classes ,
294297 threshold = 0.5 ,
295298 average = "macro" ,
296299 )
300+ elif self .num_multilabel_classes == 1 :
301+ self .multilabel_accuracy = TorchmetricAcc (task = "binary" , num_classes = self .num_multilabel_classes )
297302
298303 def _apply (self , fn : Callable , exclude_state : Sequence [str ] = "" ) -> nn .Module :
299304 self .multiclass_head_accuracy = [
@@ -303,7 +308,7 @@ def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> nn.Module:
303308 )
304309 for acc in self .multiclass_head_accuracy
305310 ]
306- if self .num_multilabel_classes > 0 :
311+ if self .multilabel_accuracy is not None :
307312 self .multilabel_accuracy = self .multilabel_accuracy ._apply (fn , exclude_state ) # noqa: SLF001
308313 return self
309314
@@ -322,7 +327,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
322327 target_multiclass [multiclass_mask ],
323328 )
324329
325- if self .num_multilabel_classes > 0 :
330+ if self .multilabel_accuracy is not None :
326331 # Split preds into multiclass and multilabel parts
327332 preds_multilabel = preds [:, self .num_multiclass_heads :]
328333 target_multilabel = target [:, self .num_multiclass_heads :]
@@ -337,7 +342,7 @@ def compute(self) -> torch.Tensor:
337342 ),
338343 )
339344
340- if self .num_multilabel_classes > 0 :
345+ if self .multilabel_accuracy is not None :
341346 multilabel_acc = self .multilabel_accuracy .compute ()
342347
343348 return (multiclass_accs + multilabel_acc ) / 2
0 commit comments