@@ -27,18 +27,21 @@ def __init__(
2727         weight_decay: float = 0.0,
2828         momentum: float = 0.9,
2929         trust_coefficient: float = 0.001,
30+          eps: float = 1e-6,
3031     ):
3132 """LARS optimizer, no rate scaling or weight decay for parameters <= 1D
3233 :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
3334 :param lr: float. learning rate
3435 :param weight_decay: float. weight decay (L2 penalty)
3536 :param momentum: float. momentum
3637 :param trust_coefficient: float. trust_coefficient
38+ :param eps: float. epsilon
3739 """
3840         self.lr = lr
3941         self.weight_decay = weight_decay
4042         self.momentum = momentum
4143         self.trust_coefficient = trust_coefficient
44+          self.eps = eps
4245
4346         self.check_valid_parameters()
4447
@@ -59,6 +62,8 @@ def check_valid_parameters(self):
5962             raise ValueError(f'Invalid momentum : {self.momentum}')
6063         if self.trust_coefficient < 0.0:
6164             raise ValueError(f'Invalid trust_coefficient : {self.trust_coefficient}')
65+          if self.eps < 0.0:
66+              raise ValueError(f'Invalid eps : {self.eps}')
6267
6368     @torch.no_grad()
6469     def step(self, closure: CLOSURE = None) -> LOSS:
@@ -84,7 +89,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
8489
8590                 q = torch.where(
8691                     param_norm > 0.0,
87-                      torch.where(update_norm > 0, (g['trust_coefficient'] * param_norm / update_norm), one),
92+                      torch.where(update_norm > 0.0, (g['trust_coefficient'] * param_norm / update_norm), one),
8893                     one,
8994                 )
9095 dp = dp .mul (q )
0 commit comments