@@ -406,7 +406,7 @@ class SGDSaI(BaseOptimizer):
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
-    :param momentum: float. momentum factor (0.0 = SignSGD, >0 = Signum).
+    :param momentum: float. coefficient used for computing the running average of the gradient.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param eps: float. term added to the denominator to improve numerical stability.
@@ -415,7 +415,7 @@ class SGDSaI(BaseOptimizer):
     def __init__(
         self,
         params: PARAMETERS,
-        lr: float = 1e-3,
+        lr: float = 1e-2,
         momentum: float = 0.9,
         weight_decay: float = 1e-2,
         weight_decouple: bool = True,
@@ -468,10 +468,11 @@ def warmup_step(self, closure: CLOSURE = None) -> LOSS:
                     raise NoSparseGradientError(str(self))
 
                 sigma = grad.std().nan_to_num_()
-                grad_norm_snr = grad.norm()
-                grad_norm_snr.div_(sigma.add_(group['eps']))
+                grad_norm = grad.norm()
 
-                self.state[p]['gsnr'] = grad_norm_snr
+                g_snr = grad_norm.div_(sigma.add_(group['eps'])) if sigma != 0.0 else grad_norm
+
+                self.state[p]['gsnr'] = g_snr
 
         self.has_warmup = True
 
@@ -488,7 +489,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()
 
         for group in self.param_groups:
-            momentum = group['momentum']
+            momentum: float = group['momentum']
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -506,8 +507,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     buf = grad
 
-                step_size = group['lr'] * state['gsnr']
-
                 self.apply_weight_decay(
                     p,
                     grad,
@@ -517,6 +516,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     False,
                 )
 
-                p.add_(buf, alpha=-step_size)
+                p.add_(buf, alpha=-group['lr'] * state['gsnr'])
 
         return loss
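
Note: the behavioural change in `warmup_step` is the `sigma != 0.0` guard. When a parameter's gradient has zero standard deviation (or NaN, zeroed by `nan_to_num_()`, as for a single-element tensor), the old code divided the gradient norm by `eps` alone, inflating the stored g-SNR; the patch falls back to the raw norm instead. A minimal standalone sketch of the patched computation, where the `grad_snr` helper name and the `eps` default are illustrative and not part of the library:

```python
import torch

def grad_snr(grad: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Mirrors the patched warmup_step: fall back to the raw gradient norm
    # when the gradient's std is zero (or NaN, zeroed by nan_to_num_()),
    # so that eps alone never determines the stored g-SNR.
    sigma = grad.std().nan_to_num_()
    grad_norm = grad.norm()
    return grad_norm.div_(sigma.add_(eps)) if sigma != 0.0 else grad_norm

# A single-element gradient has NaN std -> 0 after nan_to_num_(), so the
# guard keeps the g-SNR equal to the gradient norm (0.5 here) instead of 0.5 / eps.
print(grad_snr(torch.tensor([0.5])))
# A multi-element gradient takes the usual norm / (std + eps) path.
print(grad_snr(torch.tensor([0.1, 0.3, 0.2])))
```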