@@ -42,7 +42,7 @@ def __init__(
4242 eps : float = 0.0 ,
4343 ):
4444 self .validate_learning_rate (lr )
45- self .validate_range (momentum , 'momentum' , 0.0 , 1.0 )
45+ self .validate_range (momentum , 'momentum' , 0.0 , 1.0 , range_type = '[)' )
4646 self .validate_non_negative (weight_decay , 'weight_decay' )
4747 self .validate_non_negative (eps , 'eps' )
4848
@@ -85,14 +85,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
8585 loss = closure ()
8686
8787 group = self .param_groups [0 ]
88-
89- lr , momentum , growth_rate = group ['lr' ], group ['momentum' ], group ['growth_rate' ]
90-
91- d = group ['d' ]
92- d_lr = float (d * lr )
93-
9488 device = group ['params' ][0 ].device
9589
90+ d , lr = group ['d' ], group ['lr' ]
91+ d_lr : float = d * lr
92+
9693 g_sq = torch .tensor ([0.0 ], device = device )
9794 sk_sq_weighted_change = torch .tensor ([0.0 ], device = device )
9895 sk_l1_change = torch .tensor ([0.0 ], device = device )
@@ -199,7 +196,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
199196
200197 if lr > 0.0 :
201198 d_hat = (sk_sq_weighted - gsq_weighted ) / sk_l1
202- d = group ['d' ] = max (d , min (d_hat , d * growth_rate ))
199+ d = group ['d' ] = max (d , min (d_hat , d * group [ ' growth_rate' ] ))
203200
204201 for group in self .param_groups :
205202 group ['gsq_weighted' ] = gsq_weighted
@@ -212,11 +209,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
212209 continue
213210
214211 grad = p .grad
212+
215213 state = self .state [p ]
216214
217- alpha_k = state ['alpha_k' ]
218- sk = state ['sk' ]
219- x0 = state ['x0' ]
215+ alpha_k , sk , x0 = state ['alpha_k' ], state ['sk' ], state ['x0' ]
220216
221217 if grad .is_sparse :
222218 grad = grad .coalesce ()
@@ -232,10 +228,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
232228 loc_delta = torch .sparse_coo_tensor (grad .indices (), loc_delta_masked , grad .shape )
233229 p .add_ (loc_delta )
234230 else :
235- z = x0 - sk .div (torch .sqrt (alpha_k ) + group ['eps' ])
231+ z = x0 - sk .div (alpha_k .sqrt (). add_ ( group ['eps' ]) )
236232
237- if momentum > 0.0 :
238- p .mul_ (momentum ).add_ (z , alpha = 1.0 - momentum )
233+ if group [ ' momentum' ] > 0.0 :
234+ p .mul_ (group [ ' momentum' ] ).add_ (z , alpha = 1.0 - group [ ' momentum' ] )
239235 else :
240236 p .copy_ (z )
241237
0 commit comments