

class AdaFactor(Optimizer, BaseOptimizer):
-    r"""Adaptive Learning Rates with Sublinear Memory Cost.
+    r"""Adaptive Learning Rates with Sublinear Memory Cost, with some tweaks.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
-    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+    :param betas: Union[BETAS, None]. coefficients used for computing running averages of gradient and the squared
+        hessian trace. if betas is None, the first momentum will be skipped.
    :param decay_rate: float. coefficient used to compute running averages of square gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
@@ -27,6 +28,9 @@ class AdaFactor(Optimizer, BaseOptimizer):
        is being used.
    :param eps1: float. term added to the denominator to improve numerical stability.
    :param eps2: float. term added to the denominator to improve numerical stability.
+    :param momentum_dtype: torch.dtype. dtype of the momentum variable. The ViT paper observed that storing momentum
+        in half precision (bfloat16) does not affect training dynamics and has no effect on the outcome, while
+        reducing the memory overhead of the optimizer from 2-fold to 1.5-fold.
    """

    def __init__(
@@ -45,6 +49,7 @@ def __init__(
        warmup_init: bool = False,
        eps1: float = 1e-30,
        eps2: float = 1e-3,
+        momentum_dtype: torch.dtype = torch.bfloat16,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
@@ -56,6 +61,7 @@ def __init__(
        self.clip_threshold = clip_threshold
        self.eps1 = eps1
        self.eps2 = eps2
+        self.momentum_dtype = momentum_dtype

        defaults: DEFAULTS = {
            'lr': lr,
@@ -87,7 +93,8 @@ def reset(self):
            grad_shape: Tuple[int, ...] = grad.shape
            factored: bool = self.get_options(grad_shape)

-            state['exp_avg'] = torch.zeros_like(p)
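+            # tweak: first-moment state is allocated only when betas[0] is not None, and is
+            # kept in self.momentum_dtype (bfloat16 by default) to shrink the optimizer state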
+            if group['betas'][0] is not None:
+                state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)

            if factored:
                state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
@@ -149,7 +156,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
            else:
                group['step'] = 1

-            beta1, _ = group['betas']
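+            # beta2 is now unpacked too: it serves as an upper bound on the running
+            # second-moment estimate (see the clamp_ call below)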
+            beta1, beta2 = group['betas']

            beta2_t: float = 1.0 - math.pow(group['step'], self.decay_rate)

@@ -167,7 +174,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                factored: bool = self.get_options(grad_shape)

                if len(state) == 0:
-                    state['exp_avg'] = torch.zeros_like(p)
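+                    # same tweak as in reset(): the momentum buffer is optional and stored
+                    # in reduced precision (self.momentum_dtype)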
+                    if beta1 is not None:
+                        state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)

                    if factored:
                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
@@ -205,6 +213,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                else:
                    exp_avg_sq = state['exp_avg_sq']
                    exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
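+                    # tweak: cap the second-moment estimate at beta2, which keeps
+                    # rsqrt(exp_avg_sq) from falling below 1 / sqrt(beta2)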
+                    exp_avg_sq.clamp_(max=beta2)
+
                    torch.rsqrt(exp_avg_sq, out=update)

                if group['ams_bound']:
@@ -216,8 +226,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0)).mul_(lr)

-                exp_avg = state['exp_avg']
-                exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)
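+                # when the first momentum is enabled, the applied step is the EMA of the
+                # scaled updates rather than the raw update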
+                if beta1 is not None:
+                    exp_avg = state['exp_avg']
+                    exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)
+
+                    update = exp_avg

                self.apply_weight_decay(
                    p=p,
@@ -228,6 +241,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                    fixed_decay=group['fixed_decay'],
                )

-                p.add_(-exp_avg)
+                p.add_(-update)

        return loss
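
A minimal usage sketch of the tweaked optimizer (the pytorch_optimizer import path, the toy model, and lr=1e-3 are illustrative assumptions, not part of this commit):

    import torch
    from torch import nn

    from pytorch_optimizer import AdaFactor  # assumed package layout

    model = nn.Linear(4, 2)
    # momentum_dtype defaults to torch.bfloat16 after this change; passed explicitly for clarity
    optimizer = AdaFactor(model.parameters(), lr=1e-3, momentum_dtype=torch.bfloat16)

    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()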