-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathlr_scheduler.py
More file actions
30 lines (22 loc) · 1.08 KB
/
lr_scheduler.py
File metadata and controls
30 lines (22 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from theconf import Config as C
def adjust_learning_rate_pyramid(optimizer, max_epoch):
    """Build a LambdaLR scheduler with the PyramidNet-style step decay.

    The configured base LR is multiplied by 0.1 once 50% of ``max_epoch``
    is reached, and by a further 0.1 at 75% of ``max_epoch``.

    NOTE(review): ``LambdaLR`` multiplies the optimizer's initial lr by the
    lambda's return value, and this lambda returns an absolute lr
    (``base_lr * factor``) rather than a bare factor — presumably the
    optimizer is constructed with lr=1.0; confirm against the caller.
    """
    def _pyramid_lr(epoch):
        # Re-read the base lr from the global config on every call so any
        # config change is picked up, matching the original behavior.
        initial_lr = C.get()['lr']
        first_drop = 0.1 ** (epoch // (max_epoch * 0.5))
        second_drop = 0.1 ** (epoch // (max_epoch * 0.75))
        return initial_lr * first_drop * second_drop

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _pyramid_lr)
def adjust_learning_rate_resnet(optimizer):
    """
    Return a MultiStepLR scheduler that decays the LR by 10x at the
    milestone epochs predefined for the configured total epoch count.
    Ref: AutoAugment
    """
    # Read the total epoch count from the global config once.
    epoch = C.get()['epoch']
    # Milestones keyed by total training length: 90 is the standard
    # ResNet schedule; 270 and 300 come from the AutoAugment paper.
    milestones_by_total = {
        90: [30, 60, 80],
        270: [90, 180, 240],   # autoaugment
        300: [75, 150, 225],   # autoaugment
    }
    milestones = milestones_by_total.get(epoch)
    if milestones is None:
        raise ValueError('invalid epoch=%d for resnet scheduler' % epoch)
    return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones)