@@ -51,7 +51,7 @@ def __init__(
             beta2_cap: float = 0.999,
             momentum: Optional[float] = 0.9,
             momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
-            eps: float = 1e-30,
+            eps: Optional[float] = None,
             weight_decay: float = 0.0,
             clipping_threshold: Optional[float] = None,
             unscaled_wd: bool = False,
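
Sketch (not part of the diff): with eps now defaulting to None, the fixed 1e-30 floor is replaced by a per-dtype value resolved from the gradient dtype (see the _single_tensor_adafactor hunk below). The resolved defaults come out roughly as:

    import torch

    # square of machine eps per gradient dtype (illustration only)
    print(torch.finfo(torch.float32).eps ** 2)   # 2**-46 ~= 1.42e-14
    print(torch.finfo(torch.bfloat16).eps ** 2)  # 2**-14 ~= 6.10e-05

Both floors are far larger than the old 1e-30, scaling the stabilizer to the precision of the dtype actually in use.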
@@ -66,6 +66,7 @@ def __init__(
         else:
             assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
             momentum_dtype = torch.float32
+        # FIXME try to check if momentum dtype is appropriate for device? Torch API not great for this.
 
         defaults = dict(
             lr=lr,
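
One possible shape for the FIXME above, sketched as a hypothetical helper (not in the diff): torch.cuda.is_bf16_supported() is a real API, but PyTorch offers no generic per-device dtype capability query, so any check is best-effort.

    def _momentum_dtype_ok(dtype: torch.dtype, device: torch.device) -> bool:
        # only CUDA exposes a bf16 capability query; assume other cases are fine
        if dtype == torch.bfloat16 and device.type == 'cuda':
            return torch.cuda.is_bf16_supported()
        return True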
@@ -212,13 +213,17 @@ def _single_tensor_adafactor(
         exp_avg_sq = exp_avg_sqs[i]
         exp_avg = exp_avgs[i]
         step_t = state_steps[i]
+        if eps is None:
+            # use square of machine eps for grad dtype if not set
+            eps = torch.finfo(grad.dtype).eps ** 2
 
         # Update step
         step_t += 1
         beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
         one_minus_beta2_t = 1 - beta2_t
 
         grad_sqr = torch.square(grad) + eps
+        # NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
         if exp_avg_sq is None:
             # factorized second moment
             d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
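
For reference, a sketch of the beta2_t schedule computed above, assuming beta2_decay=0.8 (the Adafactor paper's schedule exponent; the actual default lives elsewhere in this file) and the beta2_cap=0.999 from the signature:

    for step in (1, 10, 100, 1000, 10000):
        beta2_t = min(0.999, 1.0 - step ** -0.8)
        print(step, round(beta2_t, 4))
    # 1 0.0 / 10 0.8415 / 100 0.9749 / 1000 0.996 / 10000 0.999 (capped)

The decay starts near zero, so the second-moment estimate tracks the current gradient closely early on, then approaches the cap as training progresses.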