We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d994913 commit 0754ceeCopy full SHA for 0754cee
src/aind_exaspim_image_compression/inference.py
@@ -264,8 +264,8 @@ def load_model(path, device="cuda"):
264
UNet model loaded with weights and set to evaluation mode.
265
"""
266
model = UNet()
267
- model.load_state_dict(torch.load(path))
268
- model.eval().to(device)
+ model.load_state_dict(torch.load(path, map_location=device))
+ model.eval()
269
return model
270
271
0 commit comments