
Commit 93d9eb3

Improve consistency among couple of orthogonalized optimizers (#79)
Signed-off-by: Hao Wu <[email protected]>
1 parent e363d73 commit 93d9eb3

File tree

2 files changed (+7, -6 lines):

  emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py
  emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py


emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py

Lines changed: 3 additions & 2 deletions
@@ -195,7 +195,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
         for group in self.param_groups:
             for p in group["params"]:
                 if p.dim() != 2:
-                    raise ValueError("AdaptiveMuon only supports 2D parameters")
+                    raise ValueError(f"{self.__class__.__name__} only supports 2D parameters")
                 grad = p.grad
                 if grad is None:
                     continue
@@ -223,7 +223,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 grad = exp_avg

                 with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    orth_grad = self.scaled_orthogonalize_fn(grad)
+                    group_kwargs = {k: v for k, v in group.items() if k != "params"}
+                    orth_grad = self.orthogonalize(p, grad, **group_kwargs)

                 update = self._apply_moment2_normalization(
                     orth_grad=orth_grad,
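The substantive change in this file is that AdaptiveMuon now routes its gradient through the shared orthogonalize hook and forwards every per-group hyperparameter (everything except "params") as keyword arguments, matching the base class. Below is a minimal, self-contained sketch of that pattern; the class name, the ns_steps option, and the placeholder normalization are illustrative assumptions, not the repository's actual implementation.

import torch


class ToyOrthogonalizedOptimizer(torch.optim.Optimizer):
    """Illustrative only: forwards param-group options to an orthogonalize hook."""

    def __init__(self, params, lr=1e-3, ns_steps=5):
        # "ns_steps" is a hypothetical per-group option, included only to show forwarding.
        super().__init__(params, dict(lr=lr, ns_steps=ns_steps))

    def orthogonalize(self, p, grad, *, lr, ns_steps):
        # Placeholder for a real orthogonalization (e.g. Newton-Schulz iterations).
        return grad / (grad.norm() + 1e-12)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Forward every group option except the parameter list itself,
                # mirroring the pattern adopted in this commit.
                group_kwargs = {k: v for k, v in group.items() if k != "params"}
                orth_grad = self.orthogonalize(p, p.grad, **group_kwargs)
                p.add_(orth_grad, alpha=-group["lr"])
        return loss


w = torch.nn.Parameter(torch.randn(4, 4))
opt = ToyOrthogonalizedOptimizer([w], lr=0.1)
w.sum().backward()
opt.step()  # applies p -= lr * orthogonalize(p, p.grad, **group_kwargs)

Keeping the keyword-forwarding in one place means a subclass only has to declare its options in the param-group defaults and accept them in orthogonalize, rather than re-reading the group dict itself.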

emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py

Lines changed: 4 additions & 4 deletions
@@ -140,8 +140,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

         for group in self.param_groups:
             for p in group["params"]:
-                if p.dim() == 1:
-                    raise ValueError(f"{self.__class__.__name__} does not support 1D parameters")
+                if p.dim() != 2:
+                    raise ValueError(f"{self.__class__.__name__} only supports 2D parameters")
                 grad = p.grad
                 if grad is None:
                     continue
@@ -172,11 +172,11 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

                 with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                     group_kwargs = {k: v for k, v in group.items() if k != "params"}
-                    grad = self.orthogonalize(p, grad, **group_kwargs)
+                    orth_grad = self.orthogonalize(p, grad, **group_kwargs)

                 # perform weight update
                 # scale is applied to have update RMS == 1
-                p.add_(grad, alpha=-group["lr"])
+                p.add_(orth_grad, alpha=-group["lr"])

         return loss
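Tightening the check from p.dim() == 1 to p.dim() != 2 means the base optimizer now rejects every non-matrix parameter (scalars, conv kernels, etc.), not just 1D ones. A common consequence for callers, sketched below under assumed names (the model and the AdamW fallback are illustrative, not taken from the repository), is to split parameters by dimensionality and hand only the 2D weight matrices to the orthogonalized optimizer:

import torch

# Illustrative model: contains 4D conv weights, 2D linear weights, and 1D biases.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),   # 4D weight, 1D bias
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 6 * 6, 10),         # 2D weight, 1D bias (for 3x8x8 inputs)
)

# Only 2D parameters are eligible now that anything with p.dim() != 2 raises a ValueError.
matrix_params = [p for p in model.parameters() if p.dim() == 2]
other_params = [p for p in model.parameters() if p.dim() != 2]

# Hypothetical split: matrices would go to the orthogonalized optimizer, while the
# rest (biases, conv kernels, norm scales) go to a conventional fallback such as AdamW.
fallback = torch.optim.AdamW(other_params, lr=1e-3)

The rename of grad to orth_grad in the same hunk is purely for clarity: p.add_(orth_grad, alpha=-group["lr"]) is still an in-place p -= lr * orth_grad.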
