🐛 Describe the bug
This works:
```python
from torcheval.metrics import BinaryAUROC
import torch

auroc = BinaryAUROC()
input = torch.tensor([0.1, 0.2, 0.3, 0.4], device='cuda')
target = torch.tensor([0, 0, 1, 1], device='cuda')
auroc.update(input, target)
auroc.compute()
```

This does not (identical except for the added weight):

```python
auroc = BinaryAUROC()
input = torch.tensor([0.1, 0.2, 0.3, 0.4], device='cuda')
target = torch.tensor([0, 0, 1, 1], device='cuda')
weight = torch.tensor([0.1, 0.2, 0.3, 0.4], device='cuda')
auroc.update(input, target, weight=weight)
auroc.compute()
```

It fails with:

```
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
I believe the issue is that `BinaryAUROC.update()` calls `input = input.to(self.device)` and `target = target.to(self.device)` but not `weight = weight.to(self.device)`. So I believe this should be a very simple 1-2 line fix. (I did not investigate whether other metrics might have the same issue.)
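The proposed fix can be sketched with a hypothetical stand-in class. It is not torcheval's actual `BinaryAUROC` code: the class name `WeightedAUROCStateSketch` and its `inputs`/`targets`/`weights` lists are assumptions for illustration. The class only tracks where state tensors are stored, to show `update()` moving `weight` to `self.device` alongside `input` and `target`.

```python
from typing import List, Optional, Union

import torch


class WeightedAUROCStateSketch:
    """Hypothetical stand-in (NOT torcheval's BinaryAUROC) that only stores
    update() arguments, to illustrate keeping all state on self.device."""

    def __init__(self, device: Union[str, torch.device] = "cpu") -> None:
        self.device = torch.device(device)
        self.inputs: List[torch.Tensor] = []
        self.targets: List[torch.Tensor] = []
        self.weights: List[torch.Tensor] = []

    def update(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
    ) -> "WeightedAUROCStateSketch":
        # Move the required arguments to the metric's device.
        input = input.to(self.device)
        target = target.to(self.device)
        if weight is None:
            # Default to unit weights allocated directly on the metric's device.
            weight = torch.ones_like(input, dtype=torch.float, device=self.device)
        else:
            # The move the issue reports as missing: without it, a CUDA weight
            # would be stored next to CPU input/target tensors.
            weight = weight.to(self.device)
        self.inputs.append(input)
        self.targets.append(target)
        self.weights.append(weight)
        return self
```

Because the default-weight branch also allocates on `self.device`, every stored tensor ends up on one device whether or not a weight is passed.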
Versions
irrelevant