File tree Expand file tree Collapse file tree 1 file changed +11
-5
lines changed
Expand file tree Collapse file tree 1 file changed +11
-5
lines changed Original file line number Diff line number Diff line change @@ -163,14 +163,20 @@ def main():
163163 network_data = None
164164 print ("=> creating model '{}'" .format (args .arch ))
165165
166- model = models .__dict__ [args .arch ](network_data ).cuda ()
167- model = torch .nn .DataParallel (model ).cuda ()
168- cudnn .benchmark = True
166+ if device .type == "cuda" :
167+ model = models .__dict__ [args .arch ](network_data ).cuda ()
168+ else :
169+ model = models .__dict__ [args .arch ](network_data ).cpu ()
169170
170171 assert (args .solver in ['adam' , 'sgd' ])
171172 print ('=> setting {} solver' .format (args .solver ))
172- param_groups = [{'params' : model .module .bias_parameters (), 'weight_decay' : args .bias_decay },
173- {'params' : model .module .weight_parameters (), 'weight_decay' : args .weight_decay }]
173+ param_groups = [{'params' : model .bias_parameters (), 'weight_decay' : args .bias_decay },
174+ {'params' : model .weight_parameters (), 'weight_decay' : args .weight_decay }]
175+
176+ if device .type == "cuda" :
177+ model = torch .nn .DataParallel (model ).cuda ()
178+ cudnn .benchmark = True
179+
174180 if args .solver == 'adam' :
175181 optimizer = torch .optim .Adam (param_groups , args .lr ,
176182 betas = (args .momentum , args .beta ))
You can’t perform that action at this time.
0 commit comments