File tree Expand file tree Collapse file tree 1 file changed +14
-3
lines changed
emerging_optimizers/orthogonalized_optimizers Expand file tree Collapse file tree 1 file changed +14
-3
lines changed Original file line number Diff line number Diff line change @@ -75,13 +75,24 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
7575
7676
7777def polar_via_svd (A : torch .Tensor , return_p : bool = False ) -> tuple [torch .Tensor , Optional [torch .Tensor ]]:
78- """Compute polar decomposition via SVD"""
78+ """Compute polar decomposition via SVD
7979
80+ Args:
81+ A: The input tensor to compute the polar decomposition of.
82+ return_p: Whether to return the positive-semidefinite part of the polar decomposition. p is not needed
83+ by the MOP optimizer, so by default it is not calculated to save computation. The option is provided to
84+ return full polar decomposition to match the function name.
85+
86+ Returns:
87+ A tuple containing:
88+ - The unitary part of the polar decomposition.
89+ - The positive-semidefinite part of the polar decomposition, if return_p is True.
90+ """
8091 U_svd , S , Vh = torch .linalg .svd (A , full_matrices = False )
8192 U_polar = U_svd @ Vh
8293
8394 if not return_p :
8495 return U_polar , None
8596 else :
86- H = Vh .mH @ torch .diag (S ) @ Vh
87- return U_polar , H
97+ p = Vh .mH @ torch .diag (S ) @ Vh
98+ return U_polar , p
You can’t perform that action at this time.
0 commit comments