Skip to content

Commit d4699b9

Browse files
committed
Add polar, add tests
1 parent 1b918eb commit d4699b9

File tree

4 files changed

+117
-11
lines changed

4 files changed

+117
-11
lines changed

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")
4747
include("factorizations/svd.jl")
4848
include("factorizations/truncation.jl")
4949
include("factorizations/qr.jl")
50+
include("factorizations/polar.jl")
5051
include("factorizations/orthnull.jl")
5152

5253
end

src/factorizations/orthnull.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
using MatrixAlgebraKit:
2-
MatrixAlgebraKit, left_polar, qr_compact, select_algorithm, svd_compact
2+
MatrixAlgebraKit,
3+
left_polar,
4+
lq_compact,
5+
qr_compact,
6+
right_polar,
7+
select_algorithm,
8+
svd_compact
39

410
function MatrixAlgebraKit.left_orth(
511
A::AbstractBlockSparseMatrix;
@@ -15,9 +21,7 @@ function MatrixAlgebraKit.left_orth(
1521
if kind == :qr
1622
return left_orth_qr(A, alg_qr)
1723
elseif kind == :polar
18-
# TODO: Implement this.
19-
# return left_orth_polar(A, alg_polar)
20-
return left_orth_svd(A, alg_svd)
24+
return left_orth_polar(A, alg_polar)
2125
elseif kind == :svd
2226
return left_orth_svd(A, alg_svd, trunc)
2327
else
@@ -39,7 +43,7 @@ function left_orth_svd(A, alg, trunc::Nothing=nothing)
3943
end
4044

4145
function MatrixAlgebraKit.right_orth(
42-
A;
46+
A::AbstractBlockSparseMatrix;
4347
trunc=nothing,
4448
kind=isnothing(trunc) ? :lq : :svd,
4549
alg_lq=(; positive=true),
@@ -54,9 +58,7 @@ function MatrixAlgebraKit.right_orth(
5458
# return right_orth_lq(A, alg_lq)
5559
return right_orth_svd(A, alg_svd)
5660
elseif kind == :polar
57-
# TODO: Implement this.
58-
# return right_orth_polar(A, alg_polar)
59-
return right_orth_svd(A, alg_svd)
61+
return right_orth_polar(A, alg_polar)
6062
elseif kind == :svd
6163
return right_orth_svd(A, alg_svd, trunc)
6264
else

src/factorizations/polar.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using MatrixAlgebraKit:
2+
MatrixAlgebraKit,
3+
PolarViaSVD,
4+
check_input,
5+
default_algorithm,
6+
left_polar!,
7+
right_polar!,
8+
svd_compact!
9+
10+
function MatrixAlgebraKit.check_input(
11+
::typeof(left_polar!), A::AbstractBlockSparseMatrix, WP
12+
)
13+
W, P = WP
14+
@views for I in eachblockstoredindex(A)
15+
m, n = size(A[I])
16+
m >= n ||
17+
throw(ArgumentError("each input matrix block needs at least as many rows as columns"))
18+
# check_input(left_polar!, A[I], (W[I1], P[I2]))
19+
end
20+
return nothing
21+
end
22+
23+
function MatrixAlgebraKit.left_polar!(A::AbstractBlockSparseMatrix, WP, alg::PolarViaSVD)
24+
check_input(left_polar!, A, WP)
25+
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
26+
# TODO: Use more in-place operations here, avoid `copy`.
27+
W = U * Vᴴ
28+
P = copy(Vᴴ') * S * Vᴴ
29+
return (W, P)
30+
end
31+
function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, PWᴴ, alg::PolarViaSVD)
32+
check_input(right_polar!, A, PWᴴ)
33+
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
34+
# TODO: Use more in-place operations here, avoid `copy`.
35+
Wᴴ = U * Vᴴ
36+
P = U * S * copy(U')
37+
return (P, Wᴴ)
38+
end
39+
40+
function MatrixAlgebraKit.default_algorithm(
41+
::typeof(left_polar!), a::AbstractBlockSparseMatrix; kwargs...
42+
)
43+
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
44+
end
45+
function MatrixAlgebraKit.default_algorithm(
46+
::typeof(right_polar!), a::AbstractBlockSparseMatrix; kwargs...
47+
)
48+
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
49+
end

test/test_factorizations.jl

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
22
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
33
using MatrixAlgebraKit:
4-
qr_compact, qr_full, svd_compact, svd_full, svd_trunc, truncrank, trunctol
4+
left_orth,
5+
left_polar,
6+
qr_compact,
7+
qr_full,
8+
right_orth,
9+
right_polar,
10+
svd_compact,
11+
svd_full,
12+
svd_trunc,
13+
truncrank,
14+
trunctol
515
using LinearAlgebra: LinearAlgebra
616
using Random: Random
717
using Test: @inferred, @testset, @test
@@ -156,7 +166,7 @@ end
156166
end
157167
end
158168

159-
@testset "qr_compact" for T in (Float32, Float64, ComplexF32, ComplexF64)
169+
@testset "qr_compact (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
160170
for i in [1, 2], j in [1, 2], k in [1, 2], l in [1, 2]
161171
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
162172
A[Block(1, 1)] = randn(T, i, k)
@@ -167,7 +177,7 @@ end
167177
end
168178
end
169179

170-
@testset "qr_full" for T in (Float32, Float64, ComplexF32, ComplexF64)
180+
@testset "qr_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
171181
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
172182
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
173183
A[Block(1, 1)] = randn(T, i, k)
@@ -181,3 +191,47 @@ end
181191
@test A Q * R
182192
end
183193
end
194+
195+
@testset "left_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
196+
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
197+
A[Block(1, 1)] = randn(T, 3, 2)
198+
A[Block(2, 2)] = randn(T, 4, 3)
199+
200+
U, C = left_polar(A)
201+
@test U * C A
202+
@test Matrix(U'U) LinearAlgebra.I
203+
end
204+
205+
@testset "right_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
206+
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
207+
A[Block(1, 1)] = randn(T, 2, 3)
208+
A[Block(2, 2)] = randn(T, 3, 4)
209+
210+
C, U = right_polar(A)
211+
@test C * U A
212+
@test Matrix(U * U') LinearAlgebra.I
213+
end
214+
215+
@testset "left_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
216+
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
217+
A[Block(1, 1)] = randn(T, 3, 2)
218+
A[Block(2, 2)] = randn(T, 4, 3)
219+
220+
for kind in (:qr, :polar, :svd)
221+
U, C = left_orth(A; kind)
222+
@test U * C A
223+
@test Matrix(U'U) LinearAlgebra.I
224+
end
225+
end
226+
227+
@testset "right_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
228+
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
229+
A[Block(1, 1)] = randn(T, 2, 3)
230+
A[Block(2, 2)] = randn(T, 3, 4)
231+
232+
for kind in (:qr, :polar, :svd)
233+
C, U = right_orth(A; kind)
234+
@test C * U A
235+
@test Matrix(U * U') LinearAlgebra.I
236+
end
237+
end

0 commit comments

Comments
 (0)