
Commit 42d4d2f

feature: pre_norm
1 parent 5bc875e


pytorch_optimizer/lamb.py

Lines changed: 26 additions & 1 deletion
@@ -1,3 +1,5 @@
+import math
+
 import torch
 from torch.optim import Optimizer
 
@@ -31,21 +33,24 @@ def __init__(
         weight_decay: float = 0.0,
         adam: bool = False,
         adamd_debias_term: bool = False,
+        pre_norm: bool = False,
     ):
         """
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
+        :param eps: float. term added to the denominator to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
         :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
-        :param eps: float. term added to the denominator to improve numerical stability
+        :param pre_norm: bool. perform pre-normalization of all gradients
         """
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
         self.eps = eps
         self.adam = adam
         self.adamd_debias_term = adamd_debias_term
+        self.pre_norm = pre_norm
 
         self.check_valid_parameters()
 
@@ -65,16 +70,36 @@ def check_valid_parameters(self):
         if self.eps < 0.0:
             raise ValueError(f'Invalid eps : {self.eps}')
 
+    def get_gradient_norm(self) -> float:
+        norm_sq: float = 0.0
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                norm_sq += torch.linalg.norm(p.grad).item() ** 2
+
+        norm = math.sqrt(norm_sq)
+
+        return norm
+
     def step(self, closure: CLOSURE = None) -> float:
         loss = None
         if closure is not None:
             loss = closure()
 
+        grad_norm: float = 1.0
+        if self.pre_norm:
+            grad_norm = self.get_gradient_norm()
+
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
 
+                if self.pre_norm:
+                    p.grad /= grad_norm
+
                 grad = p.grad.data
                 if grad.is_sparse:
                     raise RuntimeError('[-] Lamb does not support sparse gradients, consider SparseAdam instead.')
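Usage of the new flag, as a minimal sketch rather than anything shown in this diff: it assumes the class is exported as pytorch_optimizer.Lamb and that model is an ordinary torch.nn.Module. With pre_norm=True, step() first computes the global L2 gradient norm via get_gradient_norm() (the square root of the summed squared norms over all parameters) and divides every gradient by it before running the usual LAMB update.

import torch
from pytorch_optimizer import Lamb  # assumed export path; the diff only shows pytorch_optimizer/lamb.py

model = torch.nn.Linear(10, 2)
optimizer = Lamb(model.parameters(), lr=1e-3, pre_norm=True)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = torch.nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()  # with pre_norm=True, gradients are divided by their global L2 norm before the update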
