We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f944242 commit 4a3df78Copy full SHA for 4a3df78
timm/utils/metrics.py
@@ -28,5 +28,5 @@ def accuracy(output, target, topk=(1,)):
28
batch_size = target.size(0)
29
_, pred = output.topk(maxk, 1, True, True)
30
pred = pred.t()
31
- correct = pred.eq(target.view(1, -1).expand_as(pred))
32
- return [correct[:k].view(-1).float().sum(0) * 100. / batch_size for k in topk]
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+ return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
0 commit comments