 
 class MADGRAD(Optimizer):
     """
-    Reference : https://github.com/facebookresearch/madgrad/blob/main/madgrad/madgrad.py
+    Reference 1 : https://github.com/facebookresearch/madgrad/blob/main/madgrad/madgrad.py
+    Reference 2 : https://github.com/lessw2020/Best-Deep-Learning-Optimizers/blob/master/madgrad/madgrad_wd.py
     Example :
         from pytorch_optimizer import MADGRAD
         ...
@@ -35,11 +36,13 @@ def __init__(
         weight_decay: float = 0.0,
         eps: float = 1e-6,
     ):
-        """A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
+        """A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic (slightly modified)
         :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate.
         :param eps: float. term added to the denominator to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
+            MADGRAD optimizer requires less weight decay than other methods, often as little as zero.
+            On sparse problems both weight_decay and momentum should be set to 0.
         """
         self.lr = lr
         self.momentum = momentum
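
For context, a minimal usage sketch of this constructor; the models and hyperparameter values below are illustrative assumptions, not recommendations from the library.

# Hedged usage sketch for the constructor documented above; the models and
# hyperparameter values are illustrative placeholders only.
import torch
from pytorch_optimizer import MADGRAD

dense_model = torch.nn.Linear(10, 2)
# Dense problems: MADGRAD typically needs little or no weight decay.
dense_opt = MADGRAD(dense_model.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.0, eps=1e-6)

sparse_model = torch.nn.EmbeddingBag(1000, 16, sparse=True)
# Sparse problems (per the docstring note): set both weight_decay and momentum to 0.
sparse_opt = MADGRAD(sparse_model.parameters(), lr=1e-2, momentum=0.0, weight_decay=0.0)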
@@ -72,11 +75,6 @@ def supports_flat_params(self) -> bool:
         return True
 
     def step(self, closure: CLOSURE = None) -> LOSS:
-        """Performs a single optimization step.
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
         loss: LOSS = None
         if closure is not None:
             loss = closure()
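
The docstring removed here described the usual PyTorch closure contract, which the retained code still follows; a minimal, hedged sketch of a caller passing such a closure (the model, data, and hyperparameters are toy placeholders):

# Hedged sketch of the optional-closure pattern used by step() above;
# all names below are toy placeholders, not part of the library.
import torch
from pytorch_optimizer import MADGRAD

model = torch.nn.Linear(4, 1)
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
optimizer = MADGRAD(model.parameters(), lr=1e-2)

def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss

loss = optimizer.step(closure)  # or optimizer.step() once gradients are already computed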
@@ -124,7 +122,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                             'weight_decay option is not compatible with sparse gradients'
                         )
 
-                    grad.add_(p.data, alpha=decay)
+                    # original implementation
+                    # grad.add_(p.data, alpha=decay)
+
+                    # Apply weight decay - L2 / AdamW style
+                    p.data.mul_(1 - lr * decay)
 
                 if grad.is_sparse:
                     grad = grad.coalesce()
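
For clarity, this hunk swaps the coupled L2 penalty (decay folded into the gradient) for an AdamW-style decoupled decay applied directly to the weights; below is a minimal stand-alone comparison of the two operations on throwaway tensors, not the optimizer's own code path:

# Hedged comparison of the two weight-decay styles touched by this hunk;
# `p`, `lr`, and `decay` are throwaway values for illustration.
import torch

p = torch.randn(3, requires_grad=True)
p.grad = torch.randn(3)
lr, decay = 1e-2, 1e-4

# Old (coupled L2): the penalty is added to the gradient, so it later flows
# through MADGRAD's adaptive statistics and cube-root denominator.
grad = p.grad.data.clone()
grad.add_(p.data, alpha=decay)

# New (decoupled, AdamW style): shrink the weights directly and leave the
# gradient, and hence the adaptive statistics, untouched.
p.data.mul_(1 - lr * decay)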
@@ -174,16 +176,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     grad_sum_sq.addcmul_(grad, grad, value=_lambda)
                     rms = grad_sum_sq.pow(1 / 3).add_(eps)
 
-                    # Update s
                     s.data.add_(grad, alpha=_lambda)
 
-                    # Step
                     if momentum == 0:
                         p.data.copy_(x0.addcdiv(s, rms, value=-1))
                     else:
                         z = x0.addcdiv(s, rms, value=-1)
 
-                        # p is a moving average of z
                         p.data.mul_(1 - ck).add_(z, alpha=ck)
 
         self.state['k'] += 1
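
The dense branch retained above is the core MADGRAD update: accumulate weighted squared gradients, form a cube-root denominator, dual-average the gradients against the initial point, and optionally smooth with momentum. A self-contained toy sketch of one such update follows; the state names and constants are stand-ins (for example, `_lambda` stands in for the lr- and step-count-dependent weight used by the real optimizer):

# Hedged, self-contained sketch of one dense MADGRAD update; every tensor and
# constant below is toy state, not the optimizer's internal buffers.
import torch

p = torch.randn(3)             # current parameter value
x0 = p.clone()                 # dual-averaging anchor (initial parameter value)
s = torch.zeros(3)             # running weighted sum of gradients
grad_sum_sq = torch.zeros(3)   # running weighted sum of squared gradients
grad = torch.randn(3)          # current gradient
_lambda, eps = 1e-2, 1e-6      # stand-ins for the step-dependent weight and eps
momentum = 0.9
ck = 1 - momentum              # interpolation factor, following the reference implementation

grad_sum_sq.addcmul_(grad, grad, value=_lambda)   # accumulate squared gradients
rms = grad_sum_sq.pow(1 / 3).add_(eps)            # cube-root adaptive denominator
s.add_(grad, alpha=_lambda)                       # dual averaging of gradients

if momentum == 0:
    p.copy_(x0.addcdiv(s, rms, value=-1))         # pure dual-averaged iterate
else:
    z = x0.addcdiv(s, rms, value=-1)              # dual-averaged iterate
    p.mul_(1 - ck).add_(z, alpha=ck)              # p is a moving average of z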