Commit 047e735

fix nuclear norm scale formula
Signed-off-by: Hao Wu <[email protected]>
1 parent 0eaad00 commit 047e735

1 file changed (+1 -1 lines changed)

emerging_optimizers/orthogonalized_optimizers/mop.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
                 scale_factor = muon.get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
             else:
                 # nuclear norm scaling suggested by PolarGrad paper (https://arxiv.org/pdf/2505.21799)
-                scale_factor = S.sum().sqrt()
+                scale_factor = S.sum()
             return orth_grad * scale_factor * extra_scale_factor
 
         super().__init__(
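
The effect of the one-line change can be illustrated with a minimal sketch. This is not code from the repository. It assumes `S` in the diff holds the singular values of the gradient, and it uses a hypothetical list of values plus two helper names (`nuclear_norm`, `old_scale`) invented for illustration. The nuclear norm is the plain sum of the singular values, so it grows linearly with the gradient, while the pre-fix formula grows only with the square root.

```python
import math

def nuclear_norm(singular_values):
    """Nuclear norm: the plain sum of the singular values (post-fix formula)."""
    return sum(singular_values)

def old_scale(singular_values):
    """Pre-fix formula from the diff: the square root of that sum."""
    return math.sqrt(sum(singular_values))

# Hypothetical singular values, e.g. those of the matrix diag(3, 4).
S = [3.0, 4.0]
new = nuclear_norm(S)  # 3 + 4
old = old_scale(S)     # sqrt(3 + 4)

# Multiplying a matrix by c > 0 multiplies every singular value by c.
c = 4.0
S_scaled = [c * s for s in S]
new_ratio = nuclear_norm(S_scaled) / new  # grows by c
old_ratio = old_scale(S_scaled) / old     # grows by about sqrt(c)

print(new, old, new_ratio, old_ratio)
```

Under these assumptions, the corrected `scale_factor` scales in step with the gradient magnitude, which is the behavior the commit message's "nuclear norm" implies.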
