+import math
+
 import torch
 from torch.optim import Optimizer

@@ -31,21 +33,24 @@ def __init__(
         weight_decay: float = 0.0,
         adam: bool = False,
         adamd_debias_term: bool = False,
+        pre_norm: bool = False,
     ):
         """
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
+        :param eps: float. term added to the denominator to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
         :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
-        :param eps: float. term added to the denominator to improve numerical stability
+        :param pre_norm: bool. perform pre-normalization of all gradients
         """
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
         self.eps = eps
         self.adam = adam
         self.adamd_debias_term = adamd_debias_term
+        self.pre_norm = pre_norm

         self.check_valid_parameters()

@@ -65,16 +70,36 @@ def check_valid_parameters(self):
         if self.eps < 0.0:
             raise ValueError(f'Invalid eps : {self.eps}')

+    def get_gradient_norm(self) -> float:
+        norm_sq: float = 0.0
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                norm_sq += torch.linalg.norm(p.grad).item() ** 2
+
+        norm = math.sqrt(norm_sq)
+
+        return norm
+
     def step(self, closure: CLOSURE = None) -> float:
         loss = None
         if closure is not None:
             loss = closure()

+        grad_norm: float = 1.0
+        if self.pre_norm:
+            grad_norm = self.get_gradient_norm()
+
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue

+                if self.pre_norm:
+                    p.grad /= grad_norm
+
                 grad = p.grad.data
                 if grad.is_sparse:
                     raise RuntimeError('[-] Lamb does not support sparse gradients, consider SparseAdam instead.')
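
For reference, a minimal usage sketch of the new `pre_norm` flag, assuming this file defines the `Lamb` optimizer class shown in the diff; the import path, model, and data below are placeholders, not part of the commit:

```python
import torch
import torch.nn.functional as F

# Assumed import path for the Lamb class patched above; adjust to the actual package layout.
from pytorch_optimizer import Lamb

# Toy model and data, only to drive a single optimizer step.
model = torch.nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randn(8, 2)

# With pre_norm=True, step() divides every parameter gradient by the global
# gradient norm returned by get_gradient_norm() before the usual Lamb update.
optimizer = Lamb(model.parameters(), lr=1e-3, pre_norm=True)

optimizer.zero_grad()
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()
```

Note that `get_gradient_norm` accumulates the squared L2 norms of all parameter gradients and returns the global norm, so pre-normalization rescales every gradient by the same factor rather than normalizing each tensor independently.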