Skip to content

Commit b03cba9

Browse files
committed
Implement blockwise polar decomposition
1 parent eef41d5 commit b03cba9

File tree

3 files changed

+76
-0
lines changed

3 files changed

+76
-0
lines changed

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,6 @@ include("factorizations/svd.jl")
4848
include("factorizations/truncation.jl")
4949
include("factorizations/qr.jl")
5050
include("factorizations/lq.jl")
51+
include("factorizations/polar.jl")
5152

5253
end

src/factorizations/polar.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using MatrixAlgebraKit:
2+
MatrixAlgebraKit,
3+
PolarViaSVD,
4+
check_input,
5+
default_algorithm,
6+
left_polar!,
7+
right_polar!,
8+
svd_compact!
9+
10+
function MatrixAlgebraKit.check_input(::typeof(left_polar!), A::AbstractBlockSparseMatrix)
11+
@views for I in eachblockstoredindex(A)
12+
m, n = size(A[I])
13+
m >= n ||
14+
throw(ArgumentError("each input matrix block needs at least as many rows as columns"))
15+
end
16+
return nothing
17+
end
18+
function MatrixAlgebraKit.check_input(::typeof(right_polar!), A::AbstractBlockSparseMatrix)
19+
@views for I in eachblockstoredindex(A)
20+
m, n = size(A[I])
21+
m <= n ||
22+
throw(ArgumentError("each input matrix block needs at least as many columns as rows"))
23+
end
24+
return nothing
25+
end
26+
27+
function MatrixAlgebraKit.left_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD)
28+
check_input(left_polar!, A)
29+
# TODO: Use more in-place operations here, avoid `copy`.
30+
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
31+
W = U * Vᴴ
32+
P = copy(Vᴴ') * S * Vᴴ
33+
return (W, P)
34+
end
35+
function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD)
36+
check_input(right_polar!, A)
37+
# TODO: Use more in-place operations here, avoid `copy`.
38+
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
39+
Wᴴ = U * Vᴴ
40+
P = U * S * copy(U')
41+
return (P, Wᴴ)
42+
end
43+
44+
function MatrixAlgebraKit.default_algorithm(
45+
::typeof(left_polar!), a::AbstractBlockSparseMatrix; kwargs...
46+
)
47+
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
48+
end
49+
function MatrixAlgebraKit.default_algorithm(
50+
::typeof(right_polar!), a::AbstractBlockSparseMatrix; kwargs...
51+
)
52+
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
53+
end

test/test_factorizations.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
22
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
33
using MatrixAlgebraKit:
4+
left_polar,
45
lq_compact,
56
lq_full,
67
qr_compact,
78
qr_full,
9+
right_polar,
810
svd_compact,
911
svd_full,
1012
svd_trunc,
@@ -215,3 +217,23 @@ end
215217
@test A L * Q
216218
end
217219
end
220+
221+
@testset "left_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
222+
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
223+
A[Block(1, 1)] = randn(T, 3, 2)
224+
A[Block(2, 2)] = randn(T, 4, 3)
225+
226+
U, C = left_polar(A)
227+
@test U * C A
228+
@test Matrix(U'U) LinearAlgebra.I
229+
end
230+
231+
@testset "right_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
232+
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
233+
A[Block(1, 1)] = randn(T, 2, 3)
234+
A[Block(2, 2)] = randn(T, 3, 4)
235+
236+
C, U = right_polar(A)
237+
@test C * U A
238+
@test Matrix(U * U') LinearAlgebra.I
239+
end

0 commit comments

Comments
 (0)