Multiclass accuracy requires num_classes**2 memory #3343

@barakugav

Description

🐛 Bug

Not really a bug, but a limitation that can be lifted easily.
The current MulticlassAccuracy implementation uses memory quadratic in the number of classes.
In my case I tried to use it with hundreds of thousands of classes and hit a CUDA OOM.
From the code, it looks like the full confusion matrix is built only to keep the code simple and clean, but the same statistics can be computed with linear memory.

Thanks in advance for the support ❤️

To Reproduce

import torch
from torchmetrics.classification import MulticlassAccuracy

metric = MulticlassAccuracy(average='macro', num_classes=1_000_000)
metric.update(torch.tensor([1]), torch.tensor([1])) # <-- OOM here
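To see why this OOMs, note that the update path calls a bincount with `minlength=num_classes**2`, so it allocates a full confusion-matrix buffer. A rough estimate of the two memory footprints (assuming int64 counters, as `torch.bincount` produces):

```python
# Back-of-the-envelope memory cost, assuming 8-byte (int64) counters.
num_classes = 1_000_000

quadratic_bytes = num_classes**2 * 8  # full confusion matrix
linear_bytes = num_classes * 4 * 8    # four per-class vectors: tp, fp, fn, tn

print(f"quadratic: {quadratic_bytes / 1e12:.0f} TB")  # ~8 TB
print(f"linear:    {linear_bytes / 1e6:.0f} MB")      # ~32 MB
```

At one million classes, the quadratic buffer alone is on the order of terabytes, far beyond any single GPU, while the linear variant needs only a few tens of megabytes.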

Additional context

Fix:

# _multiclass_stat_scores_update
- unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long)
- bins = _bincount(unique_mapping, minlength=num_classes**2)
- confmat = bins.reshape(num_classes, num_classes)
- tp = confmat.diag()
- fp = confmat.sum(0) - tp
- fn = confmat.sum(1) - tp
- tn = confmat.sum() - (fp + fn + tp)
+ tp = _bincount(preds[target == preds], minlength=num_classes)
+ fp = _bincount(preds, minlength=num_classes) - tp
+ fn = _bincount(target, minlength=num_classes) - tp
+ tn = target.numel() - (tp + fp + fn)

#3342

Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)
