Commit 74db1eb

[Fix] Use a smaller eps value when adding to exp_avg_sq in apply_ams_bound() (#398)

* docs: v3.6.2 changelog
* update: exp_avg_sq_eps
* docs: v3.6.2 changelog

1 parent b0146c2 commit 74db1eb

File tree

2 files changed: +12 −3 lines changed

docs/changelogs/v3.6.2.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -4,3 +4,7 @@
 * Implement `AdaMuon` optimizer. (#394, #395)
     * [Adaptive Muon Optimizer](https://arxiv.org/abs/2507.11005v1)
+
+### Fix
+
+* Adjust the value of `eps` to the fixed value `1e-15` when adding to `exp_avg_sq`. (#397, #398)
```

pytorch_optimizer/base/optimizer.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -154,23 +154,28 @@ def apply_weight_decay(
     @staticmethod
     def apply_ams_bound(
-        ams_bound: bool, exp_avg_sq: torch.Tensor, max_exp_avg_sq: Optional[torch.Tensor], eps: float
+        ams_bound: bool,
+        exp_avg_sq: torch.Tensor,
+        max_exp_avg_sq: Optional[torch.Tensor],
+        eps: float,
+        exp_avg_sq_eps: float = 1e-15,
     ) -> torch.Tensor:
         r"""Apply AMSBound variant.

         :param ams_bound: bool. whether to apply AMSBound.
         :param exp_avg_sq: torch.Tensor. exp_avg_sq.
         :param max_exp_avg_sq: Optional[torch.Tensor]. max_exp_avg_sq.
         :param eps: float. epsilon.
+        :param exp_avg_sq_eps: float. eps value for numerical stability for exp_avg_sq.
         """
         if ams_bound:
             if torch.is_complex(max_exp_avg_sq):
                 max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq)

             torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
-            de_nom = max_exp_avg_sq.add(eps)
+            de_nom = max_exp_avg_sq.add(exp_avg_sq_eps)
         else:
-            de_nom = exp_avg_sq.add(eps)
+            de_nom = exp_avg_sq.add(exp_avg_sq_eps)

         return de_nom.sqrt_().add_(eps)
```
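The numerical effect of the change can be sketched in pure Python. This is an illustrative stand-in, not the library function: plain floats replace `torch.Tensor`, and the in-place tensor update of `max_exp_avg_sq` is simplified to `max()`.

```python
import math


def apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps, exp_avg_sq_eps=1e-15):
    """Pure-Python sketch of the patched apply_ams_bound() logic."""
    if ams_bound:
        # AMSBound keeps the running maximum of the second-moment estimate.
        max_exp_avg_sq = max(max_exp_avg_sq, exp_avg_sq)
        de_nom = max_exp_avg_sq + exp_avg_sq_eps  # before the fix: + eps
    else:
        de_nom = exp_avg_sq + exp_avg_sq_eps      # before the fix: + eps
    return math.sqrt(de_nom) + eps


# With a near-zero second moment, adding the full eps (e.g. 1e-8) *inside*
# the sqrt inflates the denominator to ~1e-4; the fixed 1e-15 keeps the
# inner term negligible, so the result stays near eps itself.
old_de_nom = math.sqrt(0.0 + 1e-8) + 1e-8                 # ~1.0e-4
new_de_nom = apply_ams_bound(False, 0.0, None, eps=1e-8)  # ~3.2e-8
```

The outer `eps` added after the square root is unchanged; only the stabilizer added to `exp_avg_sq` before the square root is pinned to `1e-15`, which tightens the denominator for parameters with very small gradient variance.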
