Commit 3969290

Move 2d shape check to orthogonalize() (#87)

Signed-off-by: Hao Wu <[email protected]>
1 parent d246453

File tree

1 file changed (+7, -2 lines)

emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -147,8 +147,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
 
         for group in self.param_groups:
             for p in group["params"]:
-                if p.dim() != 2:
-                    raise ValueError(f"{self.__class__.__name__} only supports 2D parameters")
                 grad = p.grad
                 if grad is None:
                     continue
@@ -195,6 +193,11 @@ def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> t
         For example, a scaled_orthogonalize_fn function can get attributes from p or from kwargs to determine if
         the parameter is a fused parameter and should be split for preconditioning.
 
+        Note:
+            N-D parameters can be supported by overriding this function. For example, convolution weight can be
+            supported by reshaping to [output_channels, input_channels * kernel_height * kernel_width], i.e. treating
+            convolution as matrix multiplication with im2col.
+
         Args:
             p: The parameter tensor. It is necessary to pass param tensor in addition to momentum because a lot of
                 information is only available in the param tensor, attributes for example. Although not used in
@@ -205,6 +208,8 @@ def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> t
         Returns:
             The orthogonalized gradient tensor.
         """
+        if grad.ndim != 2:
+            raise ValueError("Only 2D parameters are supported.")
         grad = self.scaled_orthogonalize_fn(grad)
         return grad
```
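The Note added in this commit suggests that N-D parameters can be handled by overriding `orthogonalize()` with a reshape. A minimal sketch of that idea is below; it is not the repository's implementation, and `svd_orthogonalize` is a hypothetical stand-in for whatever `scaled_orthogonalize_fn` the optimizer is configured with (e.g. a Newton-Schulz iteration):

```python
import torch


def svd_orthogonalize(m: torch.Tensor) -> torch.Tensor:
    # Stand-in for scaled_orthogonalize_fn: project the 2-D gradient onto
    # the nearest (semi-)orthogonal matrix via SVD, i.e. U @ V^T.
    u, _, vh = torch.linalg.svd(m, full_matrices=False)
    return u @ vh


def orthogonalize_nd(grad: torch.Tensor) -> torch.Tensor:
    # Flatten an N-D gradient (e.g. a conv weight of shape
    # [out_channels, in_channels, kernel_height, kernel_width]) to
    # [out_channels, in_channels * kernel_height * kernel_width] -- the
    # im2col view that treats convolution as matrix multiplication --
    # orthogonalize in 2-D, then restore the original shape.
    shape = grad.shape
    grad_2d = grad.reshape(shape[0], -1)
    return svd_orthogonalize(grad_2d).reshape(shape)
```

A subclass could apply the same reshape inside its `orthogonalize()` override before delegating to `self.scaled_orthogonalize_fn`, which would bypass the 2-D check this commit moves into the base-class method.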
