File tree Expand file tree Collapse file tree 1 file changed +1
-6
lines changed Expand file tree Collapse file tree 1 file changed +1
-6
lines changed Original file line number Diff line number Diff line change @@ -102,8 +102,6 @@ def __init__(
102102 self .current_lr = lr
103103 self .min_lr = warm_down_min_lr
104104
105- self .param_size : int = 0
106-
107105 defaults : DEFAULTS = dict (
108106 lr = lr ,
109107 betas = betas ,
@@ -241,11 +239,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
241239 variance_ma .mul_ (beta2 ).addcmul_ (grad , grad , value = 1.0 - beta2 )
242240 variance_ma_sum += (variance_ma / bias_correction2 ).sum ()
243241
244- if self .param_size == 0 :
245- self .param_size = param_size
246-
247242 # stable weight decay
248- variance_normalized = math .sqrt (variance_ma_sum / self . param_size )
243+ variance_normalized = math .sqrt (variance_ma_sum / param_size )
249244 if math .isnan (variance_normalized ):
250245 raise RuntimeError ('hit nan for variance_normalized' )
251246
You can’t perform that action at this time.
0 commit comments