About torchmetrics.classification.MulticlassAccuracy: why torch.long is required when 'multidim_average' set to 'samplewise'? #1969
mifan002 asked this question in Classification
Hi all,
I'm using torchmetrics.classification.MulticlassAccuracy to calculate the per-class pixel accuracy for semantic segmentation. When I leave 'multidim_average' at its default ('global'), everything works regardless of whether the input data (preds and target) passed to MulticlassAccuracy.forward() is torch.uint8 or torch.long. However, when I set 'multidim_average' to 'samplewise', the inputs HAVE to be torch.long; otherwise I get this error:
```
Traceback (most recent call last):
  File "C:\Users\fanmi\anaconda3\envs\env_thesis\lib\site-packages\torchmetrics\metric.py", line 405, in wrapped_func
    raise err
  File "C:\Users\fanmi\anaconda3\envs\env_thesis\lib\site-packages\torchmetrics\metric.py", line 395, in wrapped_func
    update(*args, **kwargs)
  File "C:\Users\fanmi\anaconda3\envs\env_thesis\lib\site-packages\torchmetrics\classification\stat_scores.py", line 317, in update
    tp, fp, tn, fn = _multiclass_stat_scores_update(
  File "C:\Users\fanmi\anaconda3\envs\env_thesis\lib\site-packages\torchmetrics\functional\classification\stat_scores.py", line 378, in _multiclass_stat_scores_update
    preds_oh = torch.nn.functional.one_hot(
RuntimeError: one_hot is only applicable to index tensor.
```
Why is this the case? And how should I interpret the error message? It seems unrelated to what's actually going on.
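For reference, the `RuntimeError` is raised by PyTorch itself rather than by torchmetrics: the last frame of the traceback is a call to `torch.nn.functional.one_hot`, and "index tensor" in its message means an `int64` (`torch.long`) tensor. The sketch below assumes only that `torch` is installed. It reproduces the behaviour with plain PyTorch and shows that a `.long()` cast is enough. The fact that 'global' accepts uint8 suggests that code path never hands the raw inputs to `one_hot`. That is an inference from the traceback, not something checked against the torchmetrics source.

```python
import torch
import torch.nn.functional as F

# Illustrative class indices stored as uint8, like the segmentation masks in the question
labels_u8 = torch.tensor([0, 2, 1, 3], dtype=torch.uint8)

# one_hot only accepts int64 index tensors, so uint8 input is rejected
# even though every value is a valid class index
try:
    F.one_hot(labels_u8, num_classes=4)
    uint8_failed = False
except RuntimeError as err:
    uint8_failed = True
    print("uint8 input:", err)

# Casting the same values to torch.long (int64) makes the call succeed
encoded = F.one_hot(labels_u8.long(), num_classes=4)
print(encoded.dtype, encoded.shape)
```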
Version Info:
torchmetrics==0.10.1
torch==1.12.1
Code snippet to reproduce the issue:
```python
import torch
from torchmetrics.classification import MulticlassAccuracy

metric = MulticlassAccuracy(num_classes=4, average="none", multidim_average='samplewise', ignore_index=0)
prediction = torch.randint(low=0, high=4, size=(1,224,224)).to(torch.uint8)
label = torch.randint(low=0, high=4, size=(1,224,224)).to(torch.uint8)
score = metric(preds=prediction, target=label)
```