Skip to content

Commit c90bee3

Browse files
committed
rename H to p in polar_via_svd. Add args doc
Signed-off-by: Hao Wu <[email protected]>
1 parent 0245374 commit c90bee3

File tree

1 file changed

+14
-3
lines changed
  • emerging_optimizers/orthogonalized_optimizers

1 file changed

+14
-3
lines changed

emerging_optimizers/orthogonalized_optimizers/mop.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,24 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
7575

7676

7777
def polar_via_svd(A: torch.Tensor, return_p: bool = False) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
78-
"""Compute polar decomposition via SVD"""
78+
"""Compute polar decomposition via SVD
7979
80+
Args:
81+
A: The input tensor to compute the polar decomposition of.
82+
return_p: Whether to return the positive-semidefinite part of the polar decomposition. p is not needed
83+
by the MOP optimizer, so by default it is not calculated to save computation. The option is provided to
84+
return full polar decomposition to match the function name.
85+
86+
Returns:
87+
A tuple containing:
88+
- The unitary part of the polar decomposition.
89+
- The positive-semidefinite part of the polar decomposition, if return_p is True.
90+
"""
8091
U_svd, S, Vh = torch.linalg.svd(A, full_matrices=False)
8192
U_polar = U_svd @ Vh
8293

8394
if not return_p:
8495
return U_polar, None
8596
else:
86-
H = Vh.mH @ torch.diag(S) @ Vh
87-
return U_polar, H
97+
p = Vh.mH @ torch.diag(S) @ Vh
98+
return U_polar, p

0 commit comments

Comments
 (0)