Skip to content

Commit dd81eb5

Browse files
committed
refactor: matrix_power
1 parent e73d9c3 commit dd81eb5

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pytorch_optimizer/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def matrix_power(matrix: torch.Tensor, power: float) -> torch.Tensor:
162162
matrix_device = matrix.device
163163

164164
# use CPU for svd for speed up
165-
u, s, v = torch.svd(matrix.cpu())
165+
u, s, vh = torch.linalg.svd(matrix.cpu(), full_matrices=False)
166+
v = vh.transpose(-2, -1).conj()
166167

167168
return (u @ s.pow_(power).diag() @ v.t()).to(matrix_device)

0 commit comments

Comments
 (0)