 
 class MADGRAD(Optimizer):
     """
-    Reference 1 : https://github.com/facebookresearch/madgrad/blob/main/madgrad/madgrad.py
-    Reference 2 : https://github.com/lessw2020/Best-Deep-Learning-Optimizers/blob/master/madgrad/madgrad_wd.py
+    Reference 1 : https://github.com/facebookresearch/madgrad
+    Reference 2 : https://github.com/lessw2020/Best-Deep-Learning-Optimizers
     Example :
         from pytorch_optimizer import MADGRAD
         ...
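The usage example in the docstring is elided above; as a reviewer's note, here is a minimal usage sketch assuming the constructor signature shown in the next hunk (the model and data are hypothetical stand-ins):

```python
import torch
from pytorch_optimizer import MADGRAD

model = torch.nn.Linear(10, 1)  # hypothetical model; any nn.Module works
optimizer = MADGRAD(model.parameters(), lr=1e-2, weight_decay=0.0)

x, y = torch.randn(32, 10), torch.randn(32, 1)  # hypothetical data
for _ in range(5):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```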
@@ -36,12 +36,16 @@ def __init__(
         weight_decay: float = 0.0,
         eps: float = 1e-6,
     ):
-        """A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic (slightly modified)
-        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
+        """A Momentumized, Adaptive, Dual Averaged Gradient Method
+        for Stochastic Optimization (slightly modified)
+        :param params: PARAMS. iterable of parameters to optimize
+        or dicts defining parameter groups
         :param lr: float. learning rate.
-        :param eps: float. term added to the denominator to improve numerical stability
+        :param eps: float. term added to the denominator
+        to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
-        MADGRAD optimizer requires less weight decay than other methods, often as little as zero
+        The MADGRAD optimizer requires less weight decay than other methods,
+        often as little as zero.
         On sparse problems both weight_decay and momentum should be set to 0.
         """
         self.lr = lr
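The docstring's sparse-problem advice amounts to a configuration like this sketch (the `Embedding` module is an illustrative stand-in for any sparse-gradient parameters):

```python
# per the docstring: on sparse problems, set momentum and weight_decay to 0
embedding = torch.nn.Embedding(10_000, 64, sparse=True)
optimizer = MADGRAD(embedding.parameters(), lr=1e-2, momentum=0.0, weight_decay=0.0)
```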
@@ -57,13 +61,13 @@ def __init__(
         super().__init__(params, defaults)
 
     def check_valid_parameters(self):
-        if 0.0 > self.lr:
+        if self.lr < 0.0:
             raise ValueError(f'Invalid learning rate : {self.lr}')
-        if 0.0 > self.eps:
+        if self.eps < 0.0:
             raise ValueError(f'Invalid eps : {self.eps}')
-        if 0.0 > self.weight_decay:
+        if self.weight_decay < 0.0:
             raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
-        if 0.0 > self.momentum or 1.0 <= self.momentum:
+        if not 0.0 <= self.momentum < 1.0:
             raise ValueError(f'Invalid momentum : {self.momentum}')
 
     @property
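Note the momentum guard keeps the original semantics, accepting the interval `[0, 1)` so that `momentum=0.0` stays valid, as the docstring recommends for sparse problems. A quick sanity check of the refactored guards, assuming `__init__` calls `check_valid_parameters()` (the call site sits outside this hunk):

```python
try:
    MADGRAD(model.parameters(), lr=-1.0)  # assumes __init__ validates its arguments
except ValueError as err:
    print(err)  # Invalid learning rate : -1.0
```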
@@ -79,8 +83,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         if closure is not None:
             loss = closure()
 
-        # step counter must be stored in state to ensure correct behavior under
-        # optimizer sharding
+        # step counter must be stored in state to
+        # ensure correct behavior under optimizer sharding
         if 'k' not in self.state:
             self.state['k'] = torch.tensor([0], dtype=torch.long)
 
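The comment's rationale: because `'k'` lives in `self.state`, it rides along in `state_dict()` / `load_state_dict()` (PyTorch passes non-tensor state keys through unchanged), so a checkpointed or resharded optimizer resumes at the correct step. A sketch, assuming `step()` increments `'k'` as in the full file:

```python
model = torch.nn.Linear(10, 1)  # hypothetical module
opt = MADGRAD(model.parameters(), lr=1e-2)

model(torch.randn(4, 10)).sum().backward()
opt.step()  # advances the shared 'k' counter held in opt.state

snapshot = opt.state_dict()  # non-parameter keys like 'k' are serialized too
opt2 = MADGRAD(model.parameters(), lr=1e-2)
opt2.load_state_dict(snapshot)  # restored optimizer continues from the same step
```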