Commit 9171d82

Yuan-Jinghui authored and rwightman committed
Enhance the numerical stability of the Cautious Optimizer
Enhance numerical stability of the Cautious Optimizer.
1 parent 17b9764 commit 9171d82

1 file changed, +2 -1 lines changed


timm/optim/adamp.py

Lines changed: 2 additions & 1 deletion
@@ -47,7 +47,8 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float, caut
 mask = (perturb * grad_perp > 0).to(grad.dtype)
 mask.div_(mask.mean().clamp_(min=1e-3))
 perturb.mul_(mask)
-
+# Enhance the numerical stability of the Cautious Optimizer
+perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
 wd = wd_ratio
 return perturb, wd
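
The added line subtracts from `perturb` its component along the normalized parameter `p_n`. One plausible reading, not confirmed by the commit message, is that the cautious mask can reintroduce a component along `p`, and the extra projection removes it again. The sketch below illustrates that effect in plain Python on a single row, standing in for one slice of the `view_func` output. The helpers `dot`, `project_out`, and `cautious_mask` and the sample vectors are hypothetical and are not part of timm.

```python
import math


def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))


def project_out(perturb, p, eps=1e-8):
    """Remove the component of `perturb` along `p`.

    Single-row analogue of:
        perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
    """
    norm = math.sqrt(dot(p, p)) + eps
    p_n = [x / norm for x in p]
    radial = dot(p_n, perturb)
    return [u - radial * n for u, n in zip(perturb, p_n)]


def cautious_mask(perturb, grad_perp, min_mean=1e-3):
    """Zero entries whose sign disagrees with `grad_perp`, then rescale
    by the mean of the mask, clamped from below (mirrors the context lines)."""
    mask = [1.0 if u * g > 0 else 0.0 for u, g in zip(perturb, grad_perp)]
    scale = max(sum(mask) / len(mask), min_mean)
    return [u * m / scale for u, m in zip(perturb, mask)]


# Illustrative values only.
p = [1.0, 2.0, 2.0]
grad_perp = [1.0, 1.0, 1.0]

perturb = project_out([0.5, -0.25, 1.0], p)  # orthogonal to p
masked = cautious_mask(perturb, grad_perp)   # masking breaks orthogonality
reprojected = project_out(masked, p)         # projecting again restores it

print(abs(dot(perturb, p)) < 1e-6)      # orthogonal before the mask
print(abs(dot(masked, p)) > 1.0)        # radial component reappears
print(abs(dot(reprojected, p)) < 1e-6)  # removed by the second projection
```

In this toy case the mask zeroes the one entry whose sign disagrees with `grad_perp`, which leaves the masked update with a nonzero component along `p`. Projecting once more makes it orthogonal to `p` again, up to the `eps` used in the normalization.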
