Skip to content

Commit 0754cee

Browse files
authored
refactor: load model
1 parent d994913 commit 0754cee

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/aind_exaspim_image_compression/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,8 @@ def load_model(path, device="cuda"):
264264
UNet model loaded with weights and set to evaluation mode.
265265
"""
266266
model = UNet()
267-
model.load_state_dict(torch.load(path))
268-
model.eval().to(device)
267+
model.load_state_dict(torch.load(path, map_location=device))
268+
model.eval()
269269
return model
270270

271271

0 commit comments

Comments
 (0)