Labels
bug / fix (Something isn't working), help wanted (Extra attention is needed)
Description
🐛 Bug
Not really a bug, but a limitation that can be solved easily.
The current implementation of MulticlassAccuracy uses memory quadratic in the number of classes.
In my specific case I tried to use it with hundreds of thousands of classes and hit a CUDA OOM.
From looking at the code, it seems the full confusion matrix is built only to keep the code simple and clean, but the same per-class statistics can be computed with linear memory.
Thanks in advance for the support ❤️
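To put numbers on the memory difference (a back-of-the-envelope sketch, assuming int64 counters as `torch.bincount` returns):

```python
num_classes = 1_000_000

# Quadratic approach: a num_classes x num_classes int64 confusion matrix.
quadratic_bytes = num_classes**2 * 8   # 8_000_000_000_000 bytes, ~8 TB

# Linear approach: four int64 vectors (tp, fp, fn, tn) of length num_classes.
linear_bytes = 4 * num_classes * 8     # 32_000_000 bytes, ~32 MB

print(quadratic_bytes, linear_bytes)
```

So at a million classes the confusion matrix alone is on the order of terabytes, which explains the immediate OOM.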
To Reproduce

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

metric = MulticlassAccuracy(average='macro', num_classes=1_000_000)
metric.update(torch.tensor([1]), torch.tensor([1]))  # <-- OOM here
```

Additional context
Fix:

```diff
# _multiclass_stat_scores_update
- unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long)
- bins = _bincount(unique_mapping, minlength=num_classes**2)
- confmat = bins.reshape(num_classes, num_classes)
- tp = confmat.diag()
- fp = confmat.sum(0) - tp
- fn = confmat.sum(1) - tp
- tn = confmat.sum() - (fp + fn + tp)
+ tp = _bincount(preds[target == preds], minlength=num_classes)
+ fp = _bincount(preds, minlength=num_classes) - tp
+ fn = _bincount(target, minlength=num_classes) - tp
+ tn = target.numel() - (tp + fp + fn)
```
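The fix relies on the identity that the confusion matrix's diagonal, column sums, and row sums are exactly the per-class match counts, prediction counts, and target counts. A minimal self-contained sketch of both paths (using NumPy's `np.bincount` in place of torchmetrics' internal `_bincount` helper, so this runs without torch) checking they agree:

```python
import numpy as np

def stat_scores_linear(preds, target, num_classes):
    # Proposed approach: three length-num_classes bincounts, no confusion matrix.
    tp = np.bincount(preds[target == preds], minlength=num_classes)
    fp = np.bincount(preds, minlength=num_classes) - tp
    fn = np.bincount(target, minlength=num_classes) - tp
    tn = target.size - (tp + fp + fn)
    return tp, fp, fn, tn

def stat_scores_quadratic(preds, target, num_classes):
    # Current approach: build the full num_classes x num_classes confusion matrix.
    unique_mapping = target * num_classes + preds
    confmat = np.bincount(unique_mapping, minlength=num_classes**2)
    confmat = confmat.reshape(num_classes, num_classes)
    tp = confmat.diagonal()
    fp = confmat.sum(0) - tp
    fn = confmat.sum(1) - tp
    tn = confmat.sum() - (fp + fn + tp)
    return tp, fp, fn, tn

preds = np.array([0, 1, 1, 2])
target = np.array([0, 1, 2, 2])
linear = stat_scores_linear(preds, target, 3)
quadratic = stat_scores_quadratic(preds, target, 3)
for a, b in zip(linear, quadratic):
    assert (a == b).all()
```

The linear version allocates only four length-`num_classes` vectors, so peak memory drops from O(num_classes²) to O(num_classes).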