Skip to content

Commit b5dc783

Browse files
authored
Implement graded LQ (#38)
1 parent bcc9470 commit b5dc783

File tree

3 files changed

+158
-16
lines changed

3 files changed

+158
-16
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.5"
4+
version = "0.4.6"
55

66
[deps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -25,7 +25,7 @@ GradedArraysTensorAlgebraExt = "TensorAlgebra"
2525

2626
[compat]
2727
BlockArrays = "1.6.0"
28-
BlockSparseArrays = "0.6.2"
28+
BlockSparseArrays = "0.6.5"
2929
Compat = "4.16.0"
3030
DerivableInterfaces = "0.4.4"
3131
FillArrays = "1.13.0"

src/factorizations.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,14 @@ using BlockSparseArrays:
99
mortar_axis
1010
using LinearAlgebra: Diagonal
1111
using MatrixAlgebraKit:
12-
MatrixAlgebraKit, qr_compact!, qr_full!, svd_compact!, svd_full!, svd_trunc!
12+
MatrixAlgebraKit,
13+
lq_compact!,
14+
lq_full!,
15+
qr_compact!,
16+
qr_full!,
17+
svd_compact!,
18+
svd_full!,
19+
svd_trunc!
1320

1421
function BlockSparseArrays.similar_output(
1522
::typeof(svd_compact!), A::GradedMatrix, S_axes, alg::BlockPermutedDiagonalAlgorithm
@@ -73,3 +80,19 @@ function BlockSparseArrays.similar_output(
7380
R = similar(A, R_axis, axes(A, 2))
7481
return Q, R
7582
end
83+
84+
function BlockSparseArrays.similar_output(
85+
::typeof(lq_compact!), A::GradedMatrix, L_axis, alg::BlockPermutedDiagonalAlgorithm
86+
)
87+
L = similar(A, axes(A, 1), L_axis)
88+
Q = similar(A, dual(L_axis), axes(A, 2))
89+
return L, Q
90+
end
91+
92+
function BlockSparseArrays.similar_output(
93+
::typeof(lq_full!), A::GradedMatrix, L_axis, alg::BlockPermutedDiagonalAlgorithm
94+
)
95+
L = similar(A, axes(A, 1), L_axis)
96+
Q = similar(A, dual(L_axis), axes(A, 2))
97+
return L, Q
98+
end

test/test_factorizations.jl

Lines changed: 132 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
11
using BlockArrays: Block, blocksizes
22
using GradedArrays: U1, dual, flux, gradedrange, trivial
33
using LinearAlgebra: I, diag, svdvals
4-
using MatrixAlgebraKit: qr_compact, qr_full, svd_compact, svd_full, svd_trunc
5-
using Test: @test, @testset
4+
using MatrixAlgebraKit:
5+
left_orth,
6+
left_polar,
7+
lq_compact,
8+
lq_full,
9+
qr_compact,
10+
qr_full,
11+
right_orth,
12+
right_polar,
13+
svd_compact,
14+
svd_full,
15+
svd_trunc
16+
using Test: @test, @test_broken, @testset
617

718
const elts = (Float32, Float64, ComplexF32, ComplexF64)
819
@testset "svd_compact (eltype=$elt)" for elt in elts
@@ -107,29 +118,33 @@ end
107118
end
108119
end
109120

110-
@testset "qr_compact (eltype=$elt)" for elt in elts
121+
@testset "qr_compact, left_orth (eltype=$elt)" for elt in elts
111122
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
112123
r1 = gradedrange([U1(0) => i, U1(1) => j])
113124
r2 = gradedrange([U1(0) => k, U1(1) => l])
114125
a = zeros(elt, r1, dual(r2))
115126
a[Block(2, 2)] = randn(elt, blocksizes(a)[2, 2])
116127
@test flux(a) == U1(0)
117-
q, r = qr_compact(a)
118-
@test q * r a
119-
@test Array(q'q) I
120-
@test flux(q) == trivial(flux(a))
121-
@test flux(r) == flux(a)
128+
for f in (qr_compact, left_orth)
129+
q, r = f(a)
130+
@test q * r a
131+
@test Array(q'q) I
132+
@test flux(q) == trivial(flux(a))
133+
@test flux(r) == flux(a)
134+
end
122135

123136
r1 = gradedrange([U1(0) => i, U1(1) => j])
124137
r2 = gradedrange([U1(0) => k, U1(1) => l])
125138
a = zeros(elt, r1, dual(r2))
126139
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
127140
@test flux(a) == U1(-1)
128-
q, r = qr_compact(a)
129-
@test q * r a
130-
@test Array(q'q) I
131-
@test flux(q) == trivial(flux(a))
132-
@test flux(r) == flux(a)
141+
for f in (qr_compact, left_orth)
142+
q, r = f(a)
143+
@test q * r a
144+
@test Array(q'q) I
145+
@test flux(q) == trivial(flux(a))
146+
@test flux(r) == flux(a)
147+
end
133148
end
134149
end
135150

@@ -160,3 +175,107 @@ end
160175
@test flux(r) == flux(a)
161176
end
162177
end
178+
179+
@testset "left_polar (eltype=$elt)" for elt in elts
180+
r1 = gradedrange([U1(0) => 3, U1(1) => 4])
181+
r2 = gradedrange([U1(0) => 2, U1(1) => 3])
182+
a = zeros(elt, r1, dual(r2))
183+
a[Block(2, 2)] = randn(elt, blocksizes(a)[2, 2])
184+
@test flux(a) == U1(0)
185+
q, r = left_polar(a)
186+
@test q * r a
187+
@test Array(q'q) I
188+
@test flux(q) == trivial(flux(a))
189+
@test flux(r) == flux(a)
190+
191+
r1 = gradedrange([U1(0) => 3, U1(1) => 4])
192+
r2 = gradedrange([U1(0) => 2, U1(1) => 3])
193+
a = zeros(elt, r1, dual(r2))
194+
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
195+
@test flux(a) == U1(-1)
196+
q, r = left_polar(a)
197+
@test q * r a
198+
@test Array(q'q) I
199+
@test_broken flux(q) == trivial(flux(a))
200+
@test_broken flux(r) == flux(a)
201+
end
202+
203+
@testset "lq_compact, right_orth (eltype=$elt)" for elt in elts
204+
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
205+
r1 = gradedrange([U1(0) => i, U1(1) => j])
206+
r2 = gradedrange([U1(0) => k, U1(1) => l])
207+
a = zeros(elt, r1, dual(r2))
208+
a[Block(2, 2)] = randn(elt, blocksizes(a)[2, 2])
209+
@test flux(a) == U1(0)
210+
l, q = lq_compact(a)
211+
@test l * q a
212+
@test Array(q * q') I
213+
@test flux(l) == flux(a)
214+
@test flux(q) == trivial(flux(a))
215+
end
216+
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
217+
r1 = gradedrange([U1(0) => i, U1(1) => j])
218+
r2 = gradedrange([U1(0) => k, U1(1) => l])
219+
a = zeros(elt, r1, dual(r2))
220+
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
221+
@test flux(a) == U1(-1)
222+
l, q = lq_compact(a)
223+
@test l * q a
224+
@test Array(q * q') I
225+
@test flux(l) == flux(a)
226+
@test flux(q) == trivial(flux(a))
227+
end
228+
end
229+
230+
@testset "lq_full (eltype=$elt)" for elt in elts
231+
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
232+
r1 = gradedrange([U1(0) => i, U1(1) => j])
233+
r2 = gradedrange([U1(0) => k, U1(1) => l])
234+
a = zeros(elt, r1, dual(r2))
235+
a[Block(2, 2)] = randn(elt, blocksizes(a)[2, 2])
236+
@test flux(a) == U1(0)
237+
l, q = lq_full(a)
238+
@test l * q a
239+
@test Array(q * q') I
240+
@test Array(q'q) I
241+
@test flux(l) == flux(a)
242+
@test flux(q) == trivial(flux(a))
243+
end
244+
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
245+
r1 = gradedrange([U1(0) => i, U1(1) => j])
246+
r2 = gradedrange([U1(0) => k, U1(1) => l])
247+
a = zeros(elt, r1, dual(r2))
248+
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
249+
@test flux(a) == U1(-1)
250+
l, q = lq_full(a)
251+
@test l * q a
252+
@test Array(q * q') I
253+
@test Array(q'q) I
254+
@test flux(l) == flux(a)
255+
@test flux(q) == trivial(flux(a))
256+
end
257+
end
258+
259+
@testset "right_polar (eltype=$elt)" for elt in elts
260+
r1 = gradedrange([U1(0) => 2, U1(1) => 3])
261+
r2 = gradedrange([U1(0) => 3, U1(1) => 4])
262+
a = zeros(elt, r1, dual(r2))
263+
a[Block(2, 2)] = randn(elt, blocksizes(a)[2, 2])
264+
@test flux(a) == U1(0)
265+
l, q = right_polar(a)
266+
@test l * q a
267+
@test Array(q * q') I
268+
@test flux(l) == flux(a)
269+
@test flux(q) == trivial(flux(a))
270+
271+
r1 = gradedrange([U1(0) => 2, U1(1) => 3])
272+
r2 = gradedrange([U1(0) => 3, U1(1) => 4])
273+
a = zeros(elt, r1, dual(r2))
274+
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
275+
@test flux(a) == U1(-1)
276+
l, q = right_polar(a)
277+
@test l * q a
278+
@test Array(q * q') I
279+
@test_broken flux(l) == flux(a)
280+
@test_broken flux(q) == trivial(flux(a))
281+
end

0 commit comments

Comments
 (0)