Skip to content

Commit 9fadf57

Browse files
committed
update: LARS
1 parent 0a84e59 commit 9fadf57

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

pytorch_optimizer/lars.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,21 @@ def __init__(
2727
weight_decay: float = 0.0,
2828
momentum: float = 0.9,
2929
trust_coefficient: float = 0.001,
30+
eps: float = 1e-6,
3031
):
3132
"""LARS optimizer, no rate scaling or weight decay for parameters <= 1D
3233
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
3334
:param lr: float. learning rate
3435
:param weight_decay: float. weight decay (L2 penalty)
3536
:param momentum: float. momentum
3637
:param trust_coefficient: float. trust_coefficient
38+
:param eps: float. epsilon
3739
"""
3840
self.lr = lr
3941
self.weight_decay = weight_decay
4042
self.momentum = momentum
4143
self.trust_coefficient = trust_coefficient
44+
self.eps = eps
4245

4346
self.check_valid_parameters()
4447

@@ -59,6 +62,8 @@ def check_valid_parameters(self):
5962
raise ValueError(f'Invalid momentum : {self.momentum}')
6063
if self.trust_coefficient < 0.0:
6164
raise ValueError(f'Invalid trust_coefficient : {self.trust_coefficient}')
65+
if self.eps < 0.0:
66+
raise ValueError(f'Invalid eps : {self.eps}')
6267

6368
@torch.no_grad()
6469
def step(self, closure: CLOSURE = None) -> LOSS:
@@ -84,7 +89,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
8489

8590
q = torch.where(
8691
param_norm > 0.0,
87-
torch.where(update_norm > 0, (g['trust_coefficient'] * param_norm / update_norm), one),
92+
torch.where(update_norm > 0.0, (g['trust_coefficient'] * param_norm / update_norm), one),
8893
one,
8994
)
9095
dp = dp.mul(q)

0 commit comments

Comments
 (0)