
Commit 7ea5016

Change adafactor_bv epsilon default
1 parent 548fdb5 commit 7ea5016


timm/optim/adafactor_bv.py

Lines changed: 6 additions & 1 deletion
@@ -51,7 +51,7 @@ def __init__(
             beta2_cap: float = 0.999,
             momentum: Optional[float] = 0.9,
             momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
-            eps: float = 1e-30,
+            eps: Optional[float] = None,
             weight_decay: float = 0.0,
             clipping_threshold: Optional[float] = None,
             unscaled_wd: bool = False,
@@ -66,6 +66,7 @@ def __init__(
         else:
             assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
             momentum_dtype = torch.float32
+        # FIXME try to check if momentum dtype is appropriate for device? Torch API not great for this.
 
         defaults = dict(
             lr=lr,
@@ -212,13 +213,17 @@ def _single_tensor_adafactor(
         exp_avg_sq = exp_avg_sqs[i]
         exp_avg = exp_avgs[i]
         step_t = state_steps[i]
+        if eps is None:
+            # use square of machine eps for grad dtype if not set
+            eps = torch.finfo(grad.dtype).eps ** 2
 
         # Update step
         step_t += 1
         beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
         one_minus_beta2_t = 1 - beta2_t
 
         grad_sqr = torch.square(grad) + eps
+        # NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
        if exp_avg_sq is None:
             # factorized second moment
             d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
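With eps now defaulting to None, the effective epsilon1 is derived from the gradient dtype at step time rather than being fixed at 1e-30. A minimal sketch of what that dtype-dependent default evaluates to (values approximate; it uses only the torch.finfo call shown in the diff):

import torch

# eps=None resolves to the square of the machine epsilon for the grad dtype,
# mirroring the new check in _single_tensor_adafactor.
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    print(dtype, torch.finfo(dtype).eps ** 2)

# torch.float32  -> ~1.42e-14
# torch.bfloat16 -> ~6.10e-05
# torch.float16  -> ~9.54e-07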
