Skip to content

Commit 13d0332

Browse files
authored
Implement blockwise polar decomposition (#125)
1 parent eef41d5 commit 13d0332

File tree

4 files changed

+83
-1
lines changed

4 files changed

+83
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.3"
4+
version = "0.6.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,6 @@ include("factorizations/svd.jl")
4848
include("factorizations/truncation.jl")
4949
include("factorizations/qr.jl")
5050
include("factorizations/lq.jl")
51+
include("factorizations/polar.jl")
5152

5253
end

src/factorizations/polar.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
using MatrixAlgebraKit:
2+
MatrixAlgebraKit,
3+
PolarViaSVD,
4+
check_input,
5+
default_algorithm,
6+
left_polar!,
7+
right_polar!,
8+
svd_compact!
9+
10+
function MatrixAlgebraKit.check_input(::typeof(left_polar!), A::AbstractBlockSparseMatrix)
11+
@views for I in eachblockstoredindex(A)
12+
m, n = size(A[I])
13+
m >= n ||
14+
throw(ArgumentError("each input matrix block needs at least as many rows as columns"))
15+
end
16+
return nothing
17+
end
18+
function MatrixAlgebraKit.check_input(::typeof(right_polar!), A::AbstractBlockSparseMatrix)
19+
@views for I in eachblockstoredindex(A)
20+
m, n = size(A[I])
21+
m <= n ||
22+
throw(ArgumentError("each input matrix block needs at least as many columns as rows"))
23+
end
24+
return nothing
25+
end
26+
27+
function MatrixAlgebraKit.left_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD)
28+
check_input(left_polar!, A)
29+
# TODO: Use more in-place operations here, avoid `copy`.
30+
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
31+
W = U * Vᴴ
32+
# TODO: `copy` is required for now because of:
33+
# https://github.com/ITensor/BlockSparseArrays.jl/issues/24
34+
# Remove when that is fixed.
35+
P = copy(Vᴴ') * S * Vᴴ
36+
return (W, P)
37+
end
38+
function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, alg::PolarViaSVD)
39+
check_input(right_polar!, A)
40+
# TODO: Use more in-place operations here, avoid `copy`.
41+
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
42+
Wᴴ = U * Vᴴ
43+
# TODO: `copy` is required for now because of:
44+
# https://github.com/ITensor/BlockSparseArrays.jl/issues/24
45+
# Remove when that is fixed.
46+
P = U * S * copy(U')
47+
return (P, Wᴴ)
48+
end
49+
50+
function MatrixAlgebraKit.default_algorithm(
51+
::typeof(left_polar!), a::AbstractBlockSparseMatrix; kwargs...
52+
)
53+
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
54+
end
55+
function MatrixAlgebraKit.default_algorithm(
56+
::typeof(right_polar!), a::AbstractBlockSparseMatrix; kwargs...
57+
)
58+
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
59+
end

test/test_factorizations.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
using BlockArrays: Block, BlockedMatrix, BlockedVector, blocks, mortar
22
using BlockSparseArrays: BlockSparseArray, BlockDiagonal, eachblockstoredindex
33
using MatrixAlgebraKit:
4+
left_polar,
45
lq_compact,
56
lq_full,
67
qr_compact,
78
qr_full,
9+
right_polar,
810
svd_compact,
911
svd_full,
1012
svd_trunc,
@@ -215,3 +217,23 @@ end
215217
@test A L * Q
216218
end
217219
end
220+
221+
@testset "left_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
222+
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
223+
A[Block(1, 1)] = randn(T, 3, 2)
224+
A[Block(2, 2)] = randn(T, 4, 3)
225+
226+
U, C = left_polar(A)
227+
@test U * C A
228+
@test Matrix(U'U) LinearAlgebra.I
229+
end
230+
231+
@testset "right_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
232+
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
233+
A[Block(1, 1)] = randn(T, 2, 3)
234+
A[Block(2, 2)] = randn(T, 3, 4)
235+
236+
C, U = right_polar(A)
237+
@test C * U A
238+
@test Matrix(U * U') LinearAlgebra.I
239+
end

0 commit comments

Comments
 (0)