
Commit 358ff43

Merge pull request #52 from kozistr/refactor/optimizers
[Refactor] BaseOptimizer
2 parents 8439f15 + f476b7c commit 358ff43

19 files changed: +279 -224 lines

pytorch_optimizer/adabelief.py

Lines changed: 13 additions & 25 deletions
@@ -3,10 +3,11 @@
 import torch
 from torch.optim.optimizer import Optimizer
 
+from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE
 
 
-class AdaBelief(Optimizer):
+class AdaBelief(Optimizer, BaseOptimizer):
     """
     Reference : https://github.com/juntang-zhuang/Adabelief-Optimizer
     Example :
@@ -37,7 +38,7 @@ def __init__(
         adamd_debias_term: bool = False,
         eps: float = 1e-16,
     ):
-        """AdaBelief
+        """AdaBelief optimizer
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
@@ -62,7 +63,7 @@ def __init__(
         self.adamd_debias_term = adamd_debias_term
         self.eps = eps
 
-        self.check_valid_parameters()
+        self.validate_parameters()
 
         defaults: DEFAULTS = dict(
             lr=lr,
@@ -75,17 +76,11 @@ def __init__(
         )
         super().__init__(params, defaults)
 
-    def check_valid_parameters(self):
-        if self.lr < 0.0:
-            raise ValueError(f'Invalid learning rate : {self.lr}')
-        if not 0.0 <= self.betas[0] < 1.0:
-            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
-        if not 0.0 <= self.betas[1] < 1.0:
-            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
-        if self.weight_decay < 0.0:
-            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
-        if self.eps < 0.0:
-            raise ValueError(f'Invalid eps : {self.eps}')
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_betas(self.betas)
+        self.validate_weight_decay(self.weight_decay)
+        self.validate_epsilon(self.eps)
 
     def __setstate__(self, state: STATE):
         super().__setstate__(state)
@@ -125,7 +120,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     grad = grad.float()
 
                 p_fp32 = p
-                if p.dtype in {torch.float16, torch.bfloat16}:
+                if p.dtype in (torch.float16, torch.bfloat16):
                     p_fp32 = p_fp32.float()
 
                 state = self.state[p]
@@ -158,14 +153,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2)
 
                 if group['amsgrad']:
-                    max_exp_avg_var = state['max_exp_avg_var']
-
-                    torch.max(
-                        max_exp_avg_var,
-                        exp_avg_var.add_(group['eps']),
-                        out=max_exp_avg_var,
-                    )
-
+                    max_exp_avg_var = torch.max(state['max_exp_avg_var'], exp_avg_var.add_(group['eps']))
                     de_nom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 else:
                     de_nom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
@@ -176,7 +164,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         step_size /= bias_correction1
                     p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size)
                 else:
-                    buffered = group['buffer'][int(state['step'] % 10)]
+                    buffered = group['buffer'][state['step'] % 10]
                     if state['step'] == buffered[0]:
                         n_sma, step_size = buffered[1], buffered[2]
                     else:
@@ -213,7 +201,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 elif step_size > 0:
                     p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
 
-                if p.dtype in {torch.float16, torch.bfloat16}:
+                if p.dtype in (torch.float16, torch.bfloat16):
                     p.copy_(p_fp32)
 
         return loss
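
A side note on the AMSGrad branch above: the removed block wrote the element-wise maximum back into state['max_exp_avg_var'] through the out= argument of torch.max, while the condensed one-liner binds the result to a local name and leaves the stored tensor untouched. The standalone sketch below (hypothetical tensors standing in for the optimizer state, not code from this PR) shows the difference between the two forms:

import torch

stored_max = torch.tensor([1.0, 5.0, 2.0])  # stands in for state['max_exp_avg_var']
candidate = torch.tensor([3.0, 4.0, 4.0])   # stands in for exp_avg_var with eps already added

# Old form: the running maximum is written back into the stored tensor.
torch.max(stored_max, candidate, out=stored_max)
print(stored_max)  # tensor([3., 5., 4.])

# New form, as in the diff: the result is a fresh tensor and the stored
# tensor keeps its previous values.
stored_max = torch.tensor([1.0, 5.0, 2.0])
max_exp_avg_var = torch.max(stored_max, candidate)
print(stored_max)       # tensor([1., 5., 2.]) -- unchanged
print(max_exp_avg_var)  # tensor([3., 5., 4.])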

pytorch_optimizer/adabound.py

Lines changed: 8 additions & 13 deletions
@@ -4,10 +4,11 @@
 import torch
 from torch.optim.optimizer import Optimizer
 
+from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE
 
 
-class AdaBound(Optimizer):
+class AdaBound(Optimizer, BaseOptimizer):
     """
     Reference : https://github.com/Luolc/AdaBound
     Example :
@@ -57,7 +58,7 @@ def __init__(
         self.fixed_decay = fixed_decay
         self.eps = eps
 
-        self.check_valid_parameters()
+        self.validate_parameters()
 
         defaults: DEFAULTS = dict(
             lr=lr,
@@ -73,17 +74,11 @@ def __init__(
 
         self.base_lrs: List[float] = [group['lr'] for group in self.param_groups]
 
-    def check_valid_parameters(self):
-        if self.lr < 0.0:
-            raise ValueError(f'Invalid learning rate : {self.lr}')
-        if self.weight_decay < 0.0:
-            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
-        if not 0.0 <= self.betas[0] < 1.0:
-            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
-        if not 0.0 <= self.betas[1] < 1.0:
-            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
-        if self.eps < 0.0:
-            raise ValueError(f'Invalid eps : {self.eps}')
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_betas(self.betas)
+        self.validate_weight_decay(self.weight_decay)
+        self.validate_epsilon(self.eps)
 
     def __setstate__(self, state: STATE):
         super().__setstate__(state)

pytorch_optimizer/adahessian.py

Lines changed: 12 additions & 16 deletions
@@ -3,10 +3,11 @@
 import torch
 from torch.optim import Optimizer
 
+from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
-class AdaHessian(Optimizer):
+class AdaHessian(Optimizer, BaseOptimizer):
     """
     Reference : https://github.com/davda54/ada-hessian
     Example :
@@ -59,7 +60,7 @@ def __init__(
         self.eps = eps
         self.seed = seed
 
-        self.check_valid_parameters()
+        self.validate_parameters()
 
         # use a separate generator that deterministically generates
         # the same `z`s across all GPUs in case of distributed training
@@ -79,19 +80,12 @@ def __init__(
             p.hess = 0.0
             self.state[p]['hessian_step'] = 0
 
-    def check_valid_parameters(self):
-        if self.lr < 0.0:
-            raise ValueError(f'Invalid learning rate : {self.lr}')
-        if self.weight_decay < 0.0:
-            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
-        if not 0.0 <= self.betas[0] < 1.0:
-            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
-        if not 0.0 <= self.betas[1] < 1.0:
-            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
-        if not 0.0 <= self.hessian_power <= 1.0:
-            raise ValueError(f'Invalid hessian_power : {self.hessian_power}')
-        if self.eps < 0.0:
-            raise ValueError(f'Invalid eps : {self.eps}')
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_betas(self.betas)
+        self.validate_weight_decay(self.weight_decay)
+        self.validate_hessian_power(self.hessian_power)
+        self.validate_epsilon(self.eps)
 
     def get_params(self) -> Iterable:
         """Gets all parameters in all param_groups with gradients"""
@@ -104,7 +98,9 @@ def zero_hessian(self):
                 p.hess.zero_()
 
     def set_hessian(self):
-        """Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter"""
+        """Computes the Hutchinson approximation of the hessian trace
+        and accumulates it for each trainable parameter
+        """
         params = []
         for p in self.get_params():
             if p.grad is None:
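
For context on the reworded docstring: Hutchinson's estimator draws a random probe vector z with +/-1 entries, forms the Hessian-vector product Hz with a second backward pass, and uses z * Hz as an unbiased estimate of the Hessian diagonal (its sum estimates the trace). The toy sketch below illustrates the idea in isolation; it is not the optimizer's actual set_hessian implementation:

import torch

# Toy function f(x) = sum(x ** 4); its Hessian is diagonal with entries 12 * x ** 2.
x = torch.randn(5, requires_grad=True)
loss = (x ** 4).sum()

# First backward pass, keeping the graph so we can differentiate again.
grad, = torch.autograd.grad(loss, x, create_graph=True)

# Rademacher probe vector: entries are +1 or -1 with equal probability.
z = torch.randint(0, 2, x.shape, dtype=x.dtype, device=x.device) * 2.0 - 1.0

# Hessian-vector product H @ z via a second differentiation of <grad, z>.
hz, = torch.autograd.grad(grad, x, grad_outputs=z)

# z * Hz is an unbiased estimate of the Hessian diagonal; averaging over
# several probes reduces the variance, and summing it estimates the trace.
print(z * hz)
print(12 * x.detach() ** 2)  # exact diagonal, for comparison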

pytorch_optimizer/adamp.py

Lines changed: 9 additions & 15 deletions
@@ -5,11 +5,12 @@
 import torch.nn.functional as F
 from torch.optim.optimizer import Optimizer
 
+from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
-class AdamP(Optimizer):
+class AdamP(Optimizer, BaseOptimizer):
     """
     Reference : https://github.com/clovaai/AdamP
     Example :
@@ -58,7 +59,7 @@ def __init__(
         self.wd_ratio = wd_ratio
         self.use_gc = use_gc
 
-        self.check_valid_parameters()
+        self.validate_parameters()
 
         defaults: DEFAULTS = dict(
             lr=lr,
@@ -72,19 +73,12 @@ def __init__(
         )
         super().__init__(params, defaults)
 
-    def check_valid_parameters(self):
-        if self.lr < 0.0:
-            raise ValueError(f'Invalid learning rate : {self.lr}')
-        if not 0.0 <= self.betas[0] < 1.0:
-            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
-        if not 0.0 <= self.betas[1] < 1.0:
-            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
-        if self.weight_decay < 0.0:
-            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
-        if self.eps < 0.0:
-            raise ValueError(f'Invalid eps : {self.eps}')
-        if not 0.0 <= self.wd_ratio < 1.0:
-            raise ValueError(f'Invalid wd_ratio : {self.wd_ratio}')
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_betas(self.betas)
+        self.validate_weight_decay(self.weight_decay)
+        self.validate_weight_decay_ratio(self.wd_ratio)
+        self.validate_epsilon(self.eps)
 
     @staticmethod
     def channel_view(x: torch.Tensor) -> torch.Tensor:
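
AdamP is the one optimizer here whose wd_ratio check moves to validate_weight_decay_ratio, which rejects values outside [0, 1). A quick usage illustration, assuming the constructor keyword is wd_ratio (as the attribute assignment above suggests) and that AdamP is exported at the package top level:

import torch
from pytorch_optimizer import AdamP

model = torch.nn.Linear(4, 2)

try:
    # wd_ratio must lie in [0, 1); 1.5 is rejected in __init__,
    # before any parameter groups are set up.
    AdamP(model.parameters(), lr=1e-3, wd_ratio=1.5)
except ValueError as e:
    print(e)  # [-] weight_decay_ratio 1.5 must be in the range [0, 1)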

pytorch_optimizer/base_optimizer.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+from abc import ABC, abstractmethod
+
+from pytorch_optimizer.types import BETAS
+
+
+class BaseOptimizer(ABC):
+    @staticmethod
+    def validate_learning_rate(learning_rate: float):
+        if learning_rate < 0.0:
+            raise ValueError(f'[-] learning rate {learning_rate} must be positive')
+
+    @staticmethod
+    def validate_beta0(beta_0: float):
+        if not 0.0 <= beta_0 < 1.0:
+            raise ValueError(f'[-] beta0 {beta_0} must be in the range [0, 1)')
+
+    @staticmethod
+    def validate_betas(betas: BETAS):
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f'[-] beta1 {betas[0]} must be in the range [0, 1)')
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f'[-] beta2 {betas[1]} must be in the range [0, 1)')
+
+    @staticmethod
+    def validate_weight_decay(weight_decay: float):
+        if weight_decay < 0.0:
+            raise ValueError(f'[-] weight_decay {weight_decay} must be non-negative')
+
+    @staticmethod
+    def validate_weight_decay_ratio(weight_decay_ratio: float):
+        if not 0.0 <= weight_decay_ratio < 1.0:
+            raise ValueError(f'[-] weight_decay_ratio {weight_decay_ratio} must be in the range [0, 1)')
+
+    @staticmethod
+    def validate_hessian_power(hessian_power: float):
+        if not 0.0 <= hessian_power <= 1.0:
+            raise ValueError(f'[-] hessian_power {hessian_power} must be in the range [0, 1]')
+
+    @staticmethod
+    def validate_trust_coefficient(trust_coefficient: float):
+        if trust_coefficient < 0.0:
+            raise ValueError(f'[-] trust_coefficient {trust_coefficient} must be non-negative')
+
+    @staticmethod
+    def validate_momentum(momentum: float):
+        if not 0.0 <= momentum < 1.0:
+            raise ValueError(f'[-] momentum {momentum} must be in the range [0, 1)')
+
+    @staticmethod
+    def validate_lookahead_k(k: int):
+        if k < 0:
+            raise ValueError(f'[-] k {k} must be non-negative')
+
+    @staticmethod
+    def validate_rho(rho: float):
+        if rho < 0.0:
+            raise ValueError(f'[-] rho {rho} must be non-negative')
+
+    @staticmethod
+    def validate_epsilon(epsilon: float):
+        if epsilon < 0.0:
+            raise ValueError(f'[-] epsilon {epsilon} must be non-negative')
+
+    @abstractmethod
+    def validate_parameters(self):
+        raise NotImplementedError
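
Taken together, the pattern this new module establishes is: inherit from both torch's Optimizer and BaseOptimizer, implement the abstract validate_parameters, and compose it from the shared static checks so each optimizer stops carrying its own if/raise boilerplate. A hypothetical minimal optimizer (not part of this PR) wired the same way, as a sketch:

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base_optimizer import BaseOptimizer


class MySGD(Optimizer, BaseOptimizer):
    """Hypothetical example showing the BaseOptimizer wiring; not in this PR."""

    def __init__(self, params, lr: float = 1e-3, momentum: float = 0.0, eps: float = 1e-8):
        self.lr = lr
        self.momentum = momentum
        self.eps = eps

        # Shared checks from BaseOptimizer replace the per-optimizer if/raise blocks.
        self.validate_parameters()

        defaults = dict(lr=lr, momentum=momentum, eps=eps)
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_momentum(self.momentum)
        self.validate_epsilon(self.eps)

    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # Plain SGD update; momentum and eps are validated above but
                # deliberately unused to keep the sketch short.
                p.data.add_(p.grad.data, alpha=-group['lr'])
        return loss

Because the checks are static methods, they can also be called without an instance (e.g. BaseOptimizer.validate_learning_rate(1e-3)), which keeps them easy to unit-test and gives every optimizer consistent error messages.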

pytorch_optimizer/diffgrad.py

Lines changed: 9 additions & 14 deletions
@@ -3,10 +3,11 @@
 import torch
 from torch.optim.optimizer import Optimizer
 
+from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE
 
 
-class DiffGrad(Optimizer):
+class DiffGrad(Optimizer, BaseOptimizer):
     """
     Reference : https://github.com/shivram1987/diffGrad
     Example :
@@ -31,7 +32,7 @@ def __init__(
         weight_decay: float = 0.0,
         adamd_debias_term: bool = False,
     ):
-        """DiffGrad
+        """DiffGrad optimizer
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
@@ -44,24 +45,18 @@ def __init__(
         self.betas = betas
         self.weight_decay = weight_decay
 
-        self.check_valid_parameters()
+        self.validate_parameters()
 
         defaults: DEFAULTS = dict(
             lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, adamd_debias_term=adamd_debias_term
        )
         super().__init__(params, defaults)
 
-    def check_valid_parameters(self):
-        if self.lr < 0.0:
-            raise ValueError(f'Invalid learning rate : {self.lr}')
-        if self.weight_decay < 0.0:
-            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
-        if not 0.0 <= self.betas[0] < 1.0:
-            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
-        if not 0.0 <= self.betas[1] < 1.0:
-            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
-        if self.eps < 0.0:
-            raise ValueError(f'Invalid eps : {self.eps}')
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_betas(self.betas)
+        self.validate_weight_decay(self.weight_decay)
+        self.validate_epsilon(self.eps)
 
     def __setstate__(self, state: STATE):
         super().__setstate__(state)
