Commit 2efc5e3

refactor: MADGRAD optimizer

1 parent: 121a3fc


pytorch_optimizer/madgrad.py

Lines changed: 16 additions & 12 deletions
@@ -13,8 +13,8 @@
 
 class MADGRAD(Optimizer):
     """
-    Reference 1 : https://github.com/facebookresearch/madgrad/blob/main/madgrad/madgrad.py
-    Reference 2 : https://github.com/lessw2020/Best-Deep-Learning-Optimizers/blob/master/madgrad/madgrad_wd.py
+    Reference 1 : https://github.com/facebookresearch/madgrad
+    Reference 2 : https://github.com/lessw2020/Best-Deep-Learning-Optimizers
     Example :
         from pytorch_optimizer import MADGRAD
         ...
@@ -36,12 +36,16 @@ def __init__(
         weight_decay: float = 0.0,
         eps: float = 1e-6,
     ):
-        """A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic (slightly modified)
-        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
+        """A Momentumized, Adaptive, Dual Averaged Gradient Method
+        for Stochastic (slightly modified)
+        :param params: PARAMS. iterable of parameters to optimize
+        or dicts defining parameter groups
         :param lr: float. learning rate.
-        :param eps: float. term added to the denominator to improve numerical stability
+        :param eps: float. term added to the denominator
+        to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
-        MADGRAD optimizer requires less weight decay than other methods, often as little as zero
+        MADGRAD optimizer requires less weight decay than other methods,
+        often as little as zero
         On sparse problems both weight_decay and momentum should be set to 0.
         """
         self.lr = lr
@@ -57,13 +61,13 @@ def __init__(
         super().__init__(params, defaults)
 
     def check_valid_parameters(self):
-        if 0.0 > self.lr:
+        if self.lr < 0.0:
             raise ValueError(f'Invalid learning rate : {self.lr}')
-        if 0.0 > self.eps:
+        if self.eps < 0.0:
             raise ValueError(f'Invalid eps : {self.eps}')
-        if 0.0 > self.weight_decay:
+        if self.weight_decay < 0.0:
             raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
-        if 0.0 > self.momentum or 1.0 <= self.momentum:
+        if not 0.0 < self.momentum <= 1.0:
             raise ValueError(f'Invalid momentum : {self.momentum}')
 
     @property
@@ -79,8 +83,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         if closure is not None:
             loss = closure()
 
-        # step counter must be stored in state to ensure correct behavior under
-        # optimizer sharding
+        # step counter must be stored in state to
+        # ensure correct behavior under optimizer sharding
         if 'k' not in self.state:
             self.state['k'] = torch.tensor([0], dtype=torch.long)
 
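For orientation, a minimal usage sketch built from the docstring's Example block and the constructor parameters visible in this diff (lr, momentum, weight_decay, eps); the concrete values below are illustrative assumptions, not defaults taken from the commit.

# Minimal usage sketch (assumed values; only the import path and parameter
# names come from the docstring and the diff above).
import torch
from pytorch_optimizer import MADGRAD

model = torch.nn.Linear(10, 1)
optimizer = MADGRAD(
    model.parameters(),
    lr=1e-2,           # must be >= 0.0 per check_valid_parameters
    momentum=0.9,      # must satisfy 0.0 < momentum <= 1.0 after this change
    weight_decay=0.0,  # MADGRAD often needs little or no weight decay
    eps=1e-6,          # added to the denominator for numerical stability
)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()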

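One detail of the check_valid_parameters hunk: the chained comparison is not a strict rewrite of the old predicate, since the accepted boundary values for momentum shift. A small standalone sketch (the functions below are hypothetical, not part of the library):

def momentum_invalid_old(m: float) -> bool:
    return 0.0 > m or 1.0 <= m    # rejects m < 0.0 and m >= 1.0

def momentum_invalid_new(m: float) -> bool:
    return not 0.0 < m <= 1.0     # rejects m <= 0.0 and m > 1.0

for m in (0.0, 0.5, 1.0):
    print(m, momentum_invalid_old(m), momentum_invalid_new(m))
# 0.0 False True   -> valid before, rejected after
# 0.5 False False
# 1.0 True False   -> rejected before, valid after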