
Commit f9f12bd

changed dampening to match Xilin's suggestion

Signed-off-by: mikail <mkhona@nvidia.com>

1 parent: a8ec40e

File tree

1 file changed (+7, -4 lines)

emerging_optimizers/psgd/psgd.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -65,7 +65,7 @@ def __init__(
         precond_lr: float = 0.1,
         precond_init_scale: float = 1.0,
         damping_noise_scale: float = 0.1,
-        min_precond_lr: float = 0.3,
+        min_precond_lr: float = 0.01,
         warmup_steps: int = 10000,
         max_update_rms: float = 0.0,
     ) -> None:
```
```diff
@@ -182,7 +182,7 @@ def _update_precond_procrustes(
     q_list: List[torch.Tensor],
     lip_const_list: List[torch.Tensor],
     exp_avg: torch.Tensor,
-    damping_noise_scale: float,
+    damping_noise_scale: float = 1e-9,
     precond_lr: float = 0.1,
     beta_lip: float = 0.9,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
```
```diff
@@ -200,7 +200,10 @@ def _update_precond_procrustes(
         q_list: List of Kronecker factors.
         lip_const_list: List of Lipschitz constants for the Kronecker factors.
     """
-    pg = apply_preconditioner(q_list, torch.add(exp_avg, torch.randn_like(exp_avg) * damping_noise_scale, alpha=1.0))
+    dampened_momentum = exp_avg + (
+        damping_noise_scale + torch.finfo(exp_avg.dtype).eps * exp_avg.abs()
+    ) * torch.randn_like(exp_avg)
+    pg = apply_preconditioner(q_list, dampened_momentum)
     total_numel = pg.numel()
     updated_q_list: List[torch.Tensor] = []
     updated_lip_const_list: List[torch.Tensor] = []
```
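In effect, this hunk replaces a single absolute noise scale with an absolute floor plus a per-element relative term. A minimal standalone sketch of the new dampening, using only the expression from the diff (the helper name `dampen_momentum` is hypothetical, not from the repo):

```python
import torch


def dampen_momentum(exp_avg: torch.Tensor, damping_noise_scale: float = 1e-9) -> torch.Tensor:
    """Add Gaussian noise with an absolute floor plus a relative, per-element term."""
    # Old behavior: exp_avg + damping_noise_scale * noise (one absolute scale for all entries).
    # New behavior: the noise scale adapts to each entry's magnitude via finfo(dtype).eps * |m|,
    # so large entries are dampened proportionally while tiny entries keep the absolute floor.
    noise_scale = damping_noise_scale + torch.finfo(exp_avg.dtype).eps * exp_avg.abs()
    return exp_avg + noise_scale * torch.randn_like(exp_avg)


# Example: for float32, eps is about 1.19e-7, so an entry of 1e4 receives noise on
# the order of 1e-3, while an entry near zero only receives the 1e-9 floor.
m = torch.tensor([0.0, 1.0, 1e4])
print(dampen_momentum(m))
```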
```diff
@@ -282,7 +285,7 @@ def _update_1d_preconditioner(
     return q, lip_const


-def _get_precond_lr(precond_lr: float, step: int, min_precond_lr: float = 0.3, warmup_steps: int = 10000) -> float:
+def _get_precond_lr(precond_lr: float, step: int, min_precond_lr: float = 0.01, warmup_steps: int = 10000) -> float:
     r"""Helper function to get preconditioner learning rate for this optimization step based on a square root schedule.

     Decaying from a higher lr down to min_precond_lr improves accuracy.
```
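The body of `_get_precond_lr` falls outside this hunk, so the following is only a hypothetical sketch of a square-root decay consistent with the signature and docstring; the function name is real, but the warmup/decay logic here is assumed, not taken from the repo:

```python
import math


def sqrt_decay_precond_lr(precond_lr: float, step: int,
                          min_precond_lr: float = 0.01,
                          warmup_steps: int = 10000) -> float:
    # Assumption: hold the initial preconditioner lr through warmup, then decay
    # it as 1/sqrt(t) and clamp at the floor. The real _get_precond_lr body is
    # not shown in this diff and may differ.
    if step < warmup_steps:
        return precond_lr
    decayed = precond_lr / math.sqrt(1.0 + (step - warmup_steps))
    return max(decayed, min_precond_lr)
```

Under a schedule like this, lowering the default floor from 0.3 to 0.01 lets the preconditioner lr keep decaying well below the old minimum; it also removes the oddity of a floor (0.3) sitting above the default starting lr (0.1).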
