How to use metrics with multi labels #526
Hi guys, I hope you are doing great! I was wondering how to correctly use a metric, e.g. `Accuracy`, with multiple labels. For example:

```python
import torch
import torchmetrics as M  # assumption: M refers to torchmetrics

num_classes = 4
pred = torch.tensor([0, 0, 0, 0])
target = torch.tensor([0, 1, 0, 1])
acc = M.Accuracy(num_classes=num_classes)
print(acc(pred, target))
```

Where

Thanks!
For multi-labels we assume that you are providing data as `[B, L]`, where `B` is the batch size and `L` is the number of labels (`num_classes` in your code). You are therefore missing the batch dimension, which is 1 in your case. A simple `unsqueeze` should be enough:

```python
pred = torch.tensor([0, 0, 0, 0]).unsqueeze(0)
target = torch.tensor([0, 1, 0, 1]).unsqueeze(0)
```
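To make the `[B, L]` layout concrete, here is a minimal sketch using plain PyTorch only. The helper `multilabel_accuracy` is hypothetical: it simply computes the fraction of (sample, label) entries that were predicted correctly. It is not the torchmetrics implementation, whose averaging options and input checks may differ between versions.

```python
import torch


def multilabel_accuracy(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """Hypothetical helper: fraction of (sample, label) entries that match.

    Expects 0/1 tensors of shape [B, L] (batch size, number of labels).
    """
    if preds.shape != targets.shape or preds.dim() != 2:
        raise ValueError("expected preds and targets of identical shape [B, L]")
    return (preds == targets).float().mean().item()


# One sample with L = 4 labels: shape [4], i.e. no batch dimension yet.
pred = torch.tensor([0, 0, 0, 0])
target = torch.tensor([0, 1, 0, 1])

# unsqueeze(0) adds the batch dimension B = 1, giving shape [1, 4].
pred_bl = pred.unsqueeze(0)
target_bl = target.unsqueeze(0)
print(pred_bl.shape)  # torch.Size([1, 4])
print(multilabel_accuracy(pred_bl, target_bl))  # 0.5 (2 of 4 labels match)

# With several samples, stack the per-sample label vectors along dim 0 (B = 2).
batch_preds = torch.stack([pred, torch.tensor([0, 1, 0, 1])])
batch_targets = torch.stack([target, torch.tensor([0, 1, 1, 1])])
print(batch_preds.shape)  # torch.Size([2, 4])
print(multilabel_accuracy(batch_preds, batch_targets))  # 0.625 (5 of 8 match)
```

The same rule applies to any batch size: each row is one sample and each column is one label, so a single sample just needs the leading dimension of size 1.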