File tree Expand file tree Collapse file tree 2 files changed +3
-2
lines changed
Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -106,7 +106,8 @@ def eval(args):
106106 torch .set_grad_enabled (False )
107107 device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
108108 model = get_model (args .model , args .dataset )
109- model_state = torch .load (args .checkpoint )['model_state' ]
109+ model_state = torch .load (args .checkpoint ,
110+ map_location = device )['model_state' ]
110111 model .load_state_dict (model_state )
111112 model .to (device )
112113 model .eval ()
Original file line number Diff line number Diff line change @@ -62,7 +62,7 @@ def train(args):
6262 if os .path .isfile (args .resume ):
6363 print ("Loading model and optimizer from checkpoint '{}'" .format \
6464 (args .resume ))
65- checkpoint = torch .load (args .resume )
65+ checkpoint = torch .load (args .resume , map_location = device )
6666 model .load_state_dict (checkpoint ['model_state' ])
6767 optimizer .load_state_dict (checkpoint ['optimizer_state' ])
6868 print ("Loaded checkpoint '{}' (epoch{})" .format (args .resume ,
You can’t perform that action at this time.
0 commit comments