Skip to content

Commit cba6ed0

Browse files
committed
增加多种优化器
1 parent 17aa4eb commit cba6ed0

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

CIFAR10_code/train.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
parser.add_argument('--epochs', type = int, default=20, help = 'Epochs')
2828
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
2929
parser.add_argument('--patience', '-p', type = int, default=7, help='patience for Early stop')
30+
parser.add_argument('--optim','-o',type = str, choices = ['sgd','adam','adamw'],help = 'choose optimizer')
31+
3032
args = parser.parse_args()
3133

3234
print(args)
@@ -86,7 +88,13 @@
early_stopping = EarlyStopping(patience = args.patience, verbose=True)
criterion = nn.CrossEntropyLoss()

# Map each --optim choice to a factory; anything not listed (i.e. 'sgd',
# or an absent value) falls back to SGD with momentum and weight decay.
_optimizer_factories = {
    'adamw': lambda params: optim.AdamW(params, lr=args.lr),
    'adam': lambda params: optim.Adam(params, lr=args.lr),
}
_build = _optimizer_factories.get(
    args.optim,
    lambda params: optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=5e-4),
)
optimizer = _build(net.parameters())

# Dynamically update the learning rate on validation-loss plateaus.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.94,verbose=True,patience = 1,min_lr = 0.000001)

epochs = args.epochs

0 commit comments

Comments
 (0)