Skip to content

Commit 969b49e

Browse files
committed
Improve polar a bit
1 parent d4699b9 commit 969b49e

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

src/factorizations/polar.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,35 @@ using MatrixAlgebraKit:
77
right_polar!,
88
svd_compact!
99

10-
function MatrixAlgebraKit.check_input(
11-
::typeof(left_polar!), A::AbstractBlockSparseMatrix, WP
12-
)
13-
W, P = WP
10+
function MatrixAlgebraKit.check_input(::typeof(left_polar!), A::AbstractBlockSparseMatrix)
1411
@views for I in eachblockstoredindex(A)
1512
m, n = size(A[I])
1613
m >= n ||
1714
throw(ArgumentError("each input matrix block needs at least as many rows as columns"))
18-
# check_input(left_polar!, A[I], (W[I1], P[I2]))
15+
end
16+
return nothing
17+
end
18+
function MatrixAlgebraKit.check_input(::typeof(right_polar!), A::AbstractBlockSparseMatrix)
19+
@views for I in eachblockstoredindex(A)
20+
m, n = size(A[I])
21+
m <= n ||
22+
throw(ArgumentError("each input matrix block needs at least as many columns as rows"))
1923
end
2024
return nothing
2125
end
2226

23-
function MatrixAlgebraKit.left_polar!(A::AbstractBlockSparseMatrix, WP, alg::PolarViaSVD)
24-
check_input(left_polar!, A, WP)
25-
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
27+
function MatrixAlgebraKit.left_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD)
28+
check_input(left_polar!, A)
2629
# TODO: Use more in-place operations here, avoid `copy`.
30+
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
2731
W = U * Vᴴ
2832
P = copy(Vᴴ') * S * Vᴴ
2933
return (W, P)
3034
end
31-
function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, PWᴴ, alg::PolarViaSVD)
32-
check_input(right_polar!, A, PWᴴ)
33-
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
35+
function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD)
36+
check_input(right_polar!, A)
3437
# TODO: Use more in-place operations here, avoid `copy`.
38+
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
3539
Wᴴ = U * Vᴴ
3640
P = U * S * copy(U')
3741
return (P, Wᴴ)

0 commit comments

Comments
 (0)