Commit 40e64a9

Remove GPU hard-coding in main.py
If torch does not detect any CUDA device, it should fall back to the CPU
1 parent 53248d1 commit 40e64a9

File tree

1 file changed: +11 -5 lines

main.py

Lines changed: 11 additions & 5 deletions
@@ -163,14 +163,20 @@ def main():
     network_data = None
     print("=> creating model '{}'".format(args.arch))

-    model = models.__dict__[args.arch](network_data).cuda()
-    model = torch.nn.DataParallel(model).cuda()
-    cudnn.benchmark = True
+    if device.type == "cuda":
+        model = models.__dict__[args.arch](network_data).cuda()
+    else:
+        model = models.__dict__[args.arch](network_data).cpu()

     assert(args.solver in ['adam', 'sgd'])
     print('=> setting {} solver'.format(args.solver))
-    param_groups = [{'params': model.module.bias_parameters(), 'weight_decay': args.bias_decay},
-                    {'params': model.module.weight_parameters(), 'weight_decay': args.weight_decay}]
+    param_groups = [{'params': model.bias_parameters(), 'weight_decay': args.bias_decay},
+                    {'params': model.weight_parameters(), 'weight_decay': args.weight_decay}]
+
+    if device.type == "cuda":
+        model = torch.nn.DataParallel(model).cuda()
+        cudnn.benchmark = True
+
     if args.solver == 'adam':
         optimizer = torch.optim.Adam(param_groups, args.lr,
                                      betas=(args.momentum, args.beta))
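The hunk branches on device.type but does not show where device is defined, since that happens outside this hunk. A minimal sketch of the selection pattern the change relies on, assuming the usual torch.cuda.is_available() check; the Linear model below is only a hypothetical stand-in for models.__dict__[args.arch](network_data):

import torch
import torch.backends.cudnn as cudnn

# Hypothetical stand-in for models.__dict__[args.arch](network_data)
model = torch.nn.Linear(10, 2)

# Select CUDA when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# DataParallel wrapping and cudnn autotuning only apply on GPU
if device.type == "cuda":
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True

Note that the parameter groups now call model.bias_parameters() and model.weight_parameters() directly instead of going through model.module: because the DataParallel wrapping is moved to after the groups are built, the model has no .module attribute at that point.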
