Labels
bug / fix, help wanted, v1.2.x
Description
🐛 Bug
torchmetrics fails when a metric serialized on a CUDA-enabled installation is later loaded on a CPU-only installation and metric.to('cpu') is called (see below). This happens even with torch.load(..., map_location='cpu').
# On CUDA-enabled torch
import torch
import torchmetrics
metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5).cuda()
print(metric._device)
# device(type='cuda', index=0)
torch.save(metric, 'test.pth')
# On CPU-only torch
m = torch.load('test.pth', map_location='cpu')
print(m._device)
# device(type='cuda', index=0)
m.to('cpu')
#[...]
# raise AssertionError("Torch not compiled with CUDA enabled")
#AssertionError: Torch not compiled with CUDA enabled
The problematic code is here:
torchmetrics/src/torchmetrics/metric.py
Lines 811 to 813 in 894de4c
# make sure to update the device attribute
# if the dummy tensor moves device by fn function we should also update the attribute
self._device = fn(torch.zeros(1, device=self.device)).device
After loading, self.device will still refer to the original CUDA device as that is what was serialized:
>>> m = torch.load('test.pth', map_location='cpu')
>>> m._device
device(type='cuda', index=0)
Constructing a tensor with torch.zeros(1, device=self.device) will then error out if CUDA is not available.
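A minimal, self-contained sketch of the underlying behavior (not using torchmetrics; DeviceHolder is a hypothetical stand-in for a Metric): map_location remaps tensor storages during unpickling, but a plain torch.device stored as an attribute is restored verbatim, so _device keeps claiming CUDA.

```python
import io
import torch

class DeviceHolder:
    """Minimal stand-in for a Metric: stores its device as a plain attribute."""
    def __init__(self):
        self._device = torch.device('cuda', 0)

buf = io.BytesIO()
torch.save(DeviceHolder(), buf)
buf.seek(0)
# map_location only remaps tensor storages; the pickled torch.device
# attribute comes back unchanged (weights_only=False needed to unpickle
# an arbitrary class on recent torch versions).
obj = torch.load(buf, map_location='cpu', weights_only=False)
print(obj._device)  # cuda:0
```

This runs fine on a CPU-only install because no CUDA tensor is ever materialized; the stale attribute only bites once something tries to construct a tensor on it.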
To Reproduce
See above.
Expected behavior
Calling metric.to('cpu') on a metric whose tensors were already loaded onto the CPU via map_location='cpu' should not raise an error.
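One possible direction for a fix (a hedged sketch only, not the actual torchmetrics patch; _safe_probe_device is a hypothetical helper name): before probing fn with a dummy tensor, fall back to CPU when the recorded device's backend is unavailable in the current install.

```python
import torch

def _safe_probe_device(recorded: torch.device, fn) -> torch.device:
    """Probe fn with a dummy tensor to learn the target device, but fall
    back to CPU if the recorded device's backend is not available here."""
    if recorded.type == 'cuda' and not torch.cuda.is_available():
        recorded = torch.device('cpu')
    return fn(torch.zeros(1, device=recorded)).device

# e.g. moving a metric whose serialized _device is CUDA, on a CPU-only box:
dev = _safe_probe_device(torch.device('cuda', 0), lambda t: t.to('cpu'))
print(dev)  # cpu
```

On a CUDA-enabled install this behaves exactly as before; on a CPU-only install the dummy tensor is simply created on CPU instead of raising "Torch not compiled with CUDA enabled".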
Environment
torchmetrics==1.2.0
Additional context
Originally reported here: