Skip to content

Commit 543ff9f

Browse files
committed
Bug fix
1 parent 2d0c3a5 commit 543ff9f

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

eval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def eval(args):
106106
torch.set_grad_enabled(False)
107107
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
108108
model = get_model(args.model, args.dataset)
109-
model_state = torch.load(args.checkpoint)['model_state']
109+
model_state = torch.load(args.checkpoint,
110+
map_location=device)['model_state']
110111
model.load_state_dict(model_state)
111112
model.to(device)
112113
model.eval()

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def train(args):
6262
if os.path.isfile(args.resume):
6363
print("Loading model and optimizer from checkpoint '{}'".format\
6464
(args.resume))
65-
checkpoint = torch.load(args.resume)
65+
checkpoint = torch.load(args.resume, map_location=device)
6666
model.load_state_dict(checkpoint['model_state'])
6767
optimizer.load_state_dict(checkpoint['optimizer_state'])
6868
print("Loaded checkpoint '{}' (epoch{})".format(args.resume,

0 commit comments

Comments
 (0)