Skip to content

Commit bcc9470

Browse files
authored
Add support for QR decomposition (#37)
1 parent 4f322a8 commit bcc9470

File tree

3 files changed

+88
-17
lines changed

3 files changed

+88
-17
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.4"
4+
version = "0.4.5"
55

66
[deps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -25,7 +25,7 @@ GradedArraysTensorAlgebraExt = "TensorAlgebra"
2525

2626
[compat]
2727
BlockArrays = "1.6.0"
28-
BlockSparseArrays = "0.6.1"
28+
BlockSparseArrays = "0.6.2"
2929
Compat = "4.16.0"
3030
DerivableInterfaces = "0.4.4"
3131
FillArrays = "1.13.0"

src/factorizations.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ using BlockSparseArrays:
88
eachblockaxis,
99
mortar_axis
1010
using LinearAlgebra: Diagonal
11-
using MatrixAlgebraKit: MatrixAlgebraKit, svd_compact!, svd_full!, svd_trunc!
11+
using MatrixAlgebraKit:
12+
MatrixAlgebraKit, qr_compact!, qr_full!, svd_compact!, svd_full!, svd_trunc!
1213

1314
function BlockSparseArrays.similar_output(
1415
::typeof(svd_compact!), A::GradedMatrix, S_axes, alg::BlockPermutedDiagonalAlgorithm
@@ -56,3 +57,19 @@ function BlockSparseArrays.similar_truncate(
5657
Ṽᴴ = similar(Vᴴ, dual(v_axis), axes(Vᴴ, 2))
5758
return Ũ, S̃, Ṽᴴ
5859
end
60+
61+
function BlockSparseArrays.similar_output(
62+
::typeof(qr_compact!), A::GradedMatrix, R_axis, alg::BlockPermutedDiagonalAlgorithm
63+
)
64+
Q = similar(A, axes(A, 1), dual(R_axis))
65+
R = similar(A, R_axis, axes(A, 2))
66+
return Q, R
67+
end
68+
69+
function BlockSparseArrays.similar_output(
70+
::typeof(qr_full!), A::GradedMatrix, R_axis, alg::BlockPermutedDiagonalAlgorithm
71+
)
72+
Q = similar(A, axes(A, 1), dual(R_axis))
73+
R = similar(A, R_axis, axes(A, 2))
74+
return Q, R
75+
end

test/test_factorizations.jl

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using BlockArrays: Block, blocksizes
2-
using GradedArrays: U1, dual, flux, gradedrange
2+
using GradedArrays: U1, dual, flux, gradedrange, trivial
33
using LinearAlgebra: I, diag, svdvals
4-
using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc
4+
using MatrixAlgebraKit: qr_compact, qr_full, svd_compact, svd_full, svd_trunc
55
using Test: @test, @testset
66

77
const elts = (Float32, Float64, ComplexF32, ComplexF64)
@@ -17,9 +17,9 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
1717
@test u * s * vᴴ a
1818
@test Array(u'u) I
1919
@test Array(vᴴ * vᴴ') I
20-
@test flux(u) == U1(0)
20+
@test flux(u) == trivial(flux(a))
2121
@test flux(s) == flux(a)
22-
@test flux(vᴴ) == U1(0)
22+
@test flux(vᴴ) == trivial(flux(a))
2323

2424
r1 = gradedrange([U1(0) => i, U1(1) => j])
2525
r2 = gradedrange([U1(0) => k, U1(1) => l])
@@ -31,9 +31,9 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
3131
@test u * s * vᴴ a
3232
@test Array(u'u) I
3333
@test Array(vᴴ * vᴴ') I
34-
@test flux(u) == U1(0)
34+
@test flux(u) == trivial(flux(a))
3535
@test flux(s) == flux(a)
36-
@test flux(vᴴ) == U1(0)
36+
@test flux(vᴴ) == trivial(flux(a))
3737
end
3838
end
3939

@@ -50,9 +50,9 @@ end
5050
@test Array(u * u') I
5151
@test Array(vᴴ * vᴴ') I
5252
@test Array(vᴴ'vᴴ) I
53-
@test flux(u) == U1(0)
53+
@test flux(u) == trivial(flux(a))
5454
@test flux(s) == flux(a)
55-
@test flux(vᴴ) == U1(0)
55+
@test flux(vᴴ) == trivial(flux(a))
5656

5757
r1 = gradedrange([U1(0) => i, U1(1) => j])
5858
r2 = gradedrange([U1(0) => k, U1(1) => l])
@@ -65,9 +65,9 @@ end
6565
@test Array(u * u') I
6666
@test Array(vᴴ * vᴴ') I
6767
@test Array(vᴴ'vᴴ) I
68-
@test flux(u) == U1(0)
68+
@test flux(u) == trivial(flux(a))
6969
@test flux(s) == flux(a)
70-
@test flux(vᴴ) == U1(0)
70+
@test flux(vᴴ) == trivial(flux(a))
7171
end
7272
end
7373

@@ -85,9 +85,9 @@ end
8585
@test size(vᴴ) == (1, size(a, 2))
8686
@test Array(u'u) I
8787
@test Array(vᴴ * vᴴ') I
88-
@test flux(u) == U1(0)
88+
@test flux(u) == trivial(flux(a))
8989
@test flux(s) == flux(a)
90-
@test flux(vᴴ) == U1(0)
90+
@test flux(vᴴ) == trivial(flux(a))
9191

9292
r1 = gradedrange([U1(0) => i, U1(1) => j])
9393
r2 = gradedrange([U1(0) => k, U1(1) => l])
@@ -101,8 +101,62 @@ end
101101
@test size(vᴴ) == (1, size(a, 2))
102102
@test Array(u'u) I
103103
@test Array(vᴴ * vᴴ') I
104-
@test flux(u) == U1(0)
104+
@test flux(u) == trivial(flux(a))
105105
@test flux(s) == flux(a)
106-
@test flux(vᴴ) == U1(0)
106+
@test flux(vᴴ) == trivial(flux(a))
107+
end
108+
end
109+
110+
@testset "qr_compact (eltype=$elt)" for elt in elts
111+
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
112+
r1 = gradedrange([U1(0) => i, U1(1) => j])
113+
r2 = gradedrange([U1(0) => k, U1(1) => l])
114+
a = zeros(elt, r1, dual(r2))
115+
a[Block(2, 2)] = randn(elt, blocksizes(a)[2, 2])
116+
@test flux(a) == U1(0)
117+
q, r = qr_compact(a)
118+
@test q * r a
119+
@test Array(q'q) I
120+
@test flux(q) == trivial(flux(a))
121+
@test flux(r) == flux(a)
122+
123+
r1 = gradedrange([U1(0) => i, U1(1) => j])
124+
r2 = gradedrange([U1(0) => k, U1(1) => l])
125+
a = zeros(elt, r1, dual(r2))
126+
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
127+
@test flux(a) == U1(-1)
128+
q, r = qr_compact(a)
129+
@test q * r a
130+
@test Array(q'q) I
131+
@test flux(q) == trivial(flux(a))
132+
@test flux(r) == flux(a)
133+
end
134+
end
135+
136+
@testset "qr_full (eltype=$elt)" for elt in elts
137+
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
138+
r1 = gradedrange([U1(0) => i, U1(1) => j])
139+
r2 = gradedrange([U1(0) => k, U1(1) => l])
140+
a = zeros(elt, r1, dual(r2))
141+
a[Block(2, 2)] = randn(elt, blocksizes(a)[2, 2])
142+
@test flux(a) == U1(0)
143+
q, r = qr_full(a)
144+
@test q * r a
145+
@test Array(q'q) I
146+
@test Array(q * q') I
147+
@test flux(q) == trivial(flux(a))
148+
@test flux(r) == flux(a)
149+
150+
r1 = gradedrange([U1(0) => i, U1(1) => j])
151+
r2 = gradedrange([U1(0) => k, U1(1) => l])
152+
a = zeros(elt, r1, dual(r2))
153+
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
154+
@test flux(a) == U1(-1)
155+
q, r = qr_full(a)
156+
@test q * r a
157+
@test Array(q'q) I
158+
@test Array(q * q') I
159+
@test flux(q) == trivial(flux(a))
160+
@test flux(r) == flux(a)
107161
end
108162
end

0 commit comments

Comments
 (0)