Labels
bug / fix, help wanted, v1.2.x
Description
🐛 Bug

`torchmetrics` seems unable to handle a metric that was serialized on a CUDA-enabled installation and then loaded on a CPU-only installation: calling `metric.to('cpu')` fails (see below). This happens even with `torch.load(..., map_location='cpu')`.
```python
# On CUDA-enabled torch
import torch
import torchmetrics

metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5).cuda()
print(metric._device)
# device(type='cuda', index=0)
torch.save(metric, 'test.pth')
```

```python
# On CPU-only torch
import torch

m = torch.load('test.pth', map_location='cpu')
print(m._device)
# device(type='cuda', index=0)
m.to('cpu')
# [...]
#     raise AssertionError("Torch not compiled with CUDA enabled")
# AssertionError: Torch not compiled with CUDA enabled
```
The problematic code is here:

torchmetrics/src/torchmetrics/metric.py, lines 811 to 813 in 894de4c:

```python
# make sure to update the device attribute
# if the dummy tensor moves device by fn function we should also update the attribute
self._device = fn(torch.zeros(1, device=self.device)).device
```
After loading, `self.device` will still refer to the original CUDA device, since that is what was serialized:

```python
>>> m = torch.load('test.pth', map_location='cpu')
>>> m._device
device(type='cuda', index=0)
```

Constructing a tensor with `torch.zeros(1, device=self.device)` will then raise if CUDA is not available.
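The failure mode can be reproduced without torch at all. Below is a minimal pure-Python sketch (all names are hypothetical stand-ins for torch internals, not the real torchmetrics code): the probe tensor is allocated on the *recorded* device before `fn` gets a chance to move it anywhere.

```python
class FakeBackend:
    """Hypothetical stand-in for torch on a CPU-only install."""
    available_devices = {"cpu"}

    @staticmethod
    def zeros(device):
        # Mirrors torch.zeros(1, device=...) raising on an unavailable device.
        if device.split(":")[0] not in FakeBackend.available_devices:
            raise AssertionError("Torch not compiled with CUDA enabled")
        return device  # stands in for a tensor living on `device`


class SketchMetric:
    """Hypothetical metric that tracks its device in an attribute."""

    def __init__(self):
        self._device = "cpu"

    def _apply(self, fn):
        # The problematic pattern: the probe tensor is allocated on the
        # recorded (possibly stale) device before fn can move it.
        self._device = fn(FakeBackend.zeros(self._device))
        return self


m = SketchMetric()
m._device = "cuda:0"  # simulates state restored from a CUDA checkpoint

try:
    m._apply(lambda t: "cpu")  # corresponds to m.to('cpu')
except AssertionError as exc:
    print(exc)  # Torch not compiled with CUDA enabled
```

The `AssertionError` here is raised by the device check itself, before the move function ever runs, which matches the traceback above.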
To Reproduce
See above.
Expected behavior
Calling `metric.to('cpu')` on a metric whose state has been loaded onto the CPU should not raise an error.
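One possible direction for a fix (a sketch only, not the actual torchmetrics patch) is to allocate the probe tensor on a device that is always available, so the stale `_device` value is never dereferenced; `fn` still determines the final device. The sketch below uses the same hypothetical torch-free stand-ins as above, and all class names are invented for illustration.

```python
class FakeBackend:
    """Hypothetical stand-in for torch on a CPU-only install."""
    available_devices = {"cpu"}

    @staticmethod
    def zeros(device):
        # Mirrors torch.zeros(1, device=...) raising on an unavailable device.
        if device.split(":")[0] not in FakeBackend.available_devices:
            raise AssertionError("Torch not compiled with CUDA enabled")
        return device  # stands in for a tensor living on `device`


class PatchedMetric:
    """Hypothetical metric whose _apply never touches the stale device."""

    def __init__(self):
        self._device = "cpu"

    def _apply(self, fn):
        # Sketch of a fix: probe on 'cpu' (always available) instead of
        # self._device, and let fn decide the destination device.
        self._device = fn(FakeBackend.zeros("cpu"))
        return self


m = PatchedMetric()
m._device = "cuda:0"  # stale device restored from a CUDA checkpoint
m._apply(lambda t: "cpu")
print(m._device)  # cpu
```

With this pattern, `.to('cpu')` succeeds and the device attribute is refreshed to wherever `fn` actually placed the probe tensor.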
Environment
torchmetrics==1.2.0
Additional context
Originally reported here: