 class Ranger25(BaseOptimizer):
     r"""Mixin' every fancy optimizer hacks.
 
-    ADOPT + AdEMAMix + Cautious + StableAdamW + Adam-Atan2
+    ADOPT + AdEMAMix + Cautious + StableAdamW + Adam-Atan2 + OrthoGrad
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
@@ -23,6 +23,7 @@ class Ranger25(BaseOptimizer):
     :param t_alpha_beta3: Optional[float]. total number of iterations is preferred when needed.
     :param cautious: bool. whether to use the Cautious variant.
     :param stable_adamw: bool. whether to use stable AdamW variant.
+    :param orthograd: bool. whether to use orthograd variant.
     :param eps: Optional[float]. term added to the denominator to improve numerical stability. when eps is None and
         stable_adamw is False, adam-atan2 feature will be used.
     """
@@ -39,6 +40,7 @@ def __init__(
         t_alpha_beta3: Optional[float] = None,
         cautious: bool = True,
         stable_adamw: bool = True,
+        orthograd: bool = True,
         eps: Optional[float] = 1e-8,
         **kwargs,
     ):
@@ -51,6 +53,7 @@ def __init__(
 
         self.cautious = cautious
         self.stable_adamw: bool = stable_adamw if isinstance(eps, float) else False
+        self.orthograd = orthograd
 
         defaults: DEFAULTS = {
             'lr': lr,
@@ -97,13 +100,32 @@ def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta
             beta3,
         )
 
+    @torch.no_grad()
+    def orthogonalize_gradients(self, params, eps: float = 1e-16) -> None:
+        for p in params:
+            if p.grad is None or p.grad.is_sparse:
+                continue
+
+            w = p.view(-1)
+            g = p.grad.view(-1)
+
+            proj = torch.dot(w, g).div_(torch.dot(w, w).add_(eps))
+            g_ortho = g.to(dtype=torch.float32, copy=True).sub_(w, alpha=proj)
+            g_ortho_scaled = g_ortho.mul_(g.norm(2).div_(g_ortho.norm(2).add_(eps)))
+
+            p.grad.copy_(g_ortho_scaled.view_as(p.grad))
+
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
+        if self.orthograd:
+            for group in self.param_groups:
+                self.orthogonalize_gradients(group['params'])
+
         for group in self.param_groups:
             if 'step' in group:
                 group['step'] += 1
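For reference, the transformation applied by the new orthogonalize_gradients method is the OrthoGrad-style projection: the component of each gradient that is parallel to its parameter vector is removed, and the remainder is rescaled back to the original gradient norm. A minimal standalone sketch of the same math (tensor names and shapes are illustrative only):

import torch

# Toy parameter and gradient, flattened the same way orthogonalize_gradients does.
w = torch.randn(1024)
g = torch.randn(1024)
eps = 1e-16

# Remove the component of g that is parallel to w ...
proj = torch.dot(w, g) / (torch.dot(w, w) + eps)
g_ortho = g - proj * w

# ... then rescale so the result keeps the original gradient norm.
g_ortho *= g.norm(2) / (g_ortho.norm(2) + eps)

print(torch.dot(w, g_ortho))        # ~0: orthogonal to the parameter vector
print(g.norm(2), g_ortho.norm(2))   # the two norms match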
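Assuming Ranger25 is exported from pytorch_optimizer like the repository's other optimizers (the import path, model, and hyper-parameters below are illustrative, not taken from this commit), the new flag is toggled at construction time:

import torch
from pytorch_optimizer import Ranger25  # assumed export path

model = torch.nn.Linear(16, 4)
# orthograd=True (the default added in this commit) orthogonalizes the gradients
# in-place against the parameters at the start of every step().
optimizer = Ranger25(model.parameters(), lr=1e-3, orthograd=True)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()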