Skip to content

Commit a6603e1

Browse files
committed
修改部分参数
1 parent 3ad009f commit a6603e1

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

CIFAR10_code/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,11 @@
6868
from nets.MobileNetv2 import MobileNetV2
6969
net = MobileNetV2()
7070

71-
if args.cuda and torch.cuda.is_available():
71+
if args.cuda:
7272
device = 'cuda'
7373
net = torch.nn.DataParallel(net)
74-
cudnn.benchmark = True
74+
# 当计算图不会改变的时候(每次输入形状相同,模型不改变)的情况下可以提高性能,反之则降低性能
75+
torch.backends.cudnn.benchmark = True
7576
else:
7677
device = 'cpu'
7778

0 commit comments

Comments
 (0)