Skip to content

Commit bf2173e

Browse files
committed
[WIP] Add support for block sparse QR decomposition
1 parent d360f75 commit bf2173e

File tree

2 files changed

+203
-0
lines changed

2 files changed

+203
-0
lines changed

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,6 @@ include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")
4646
# factorizations
4747
include("factorizations/svd.jl")
4848
include("factorizations/truncation.jl")
49+
include("factorizations/qr.jl")
4950

5051
end

src/factorizations/qr.jl

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
using MatrixAlgebraKit: MatrixAlgebraKit, qr_compact!, qr_full!
2+
3+
# TODO: this is a hardcoded for now to get around this function not being defined in the
4+
# type domain
5+
function MatrixAlgebraKit.default_qr_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
6+
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
7+
error("unsupported type: $(blocktype(A))")
8+
alg = MatrixAlgebraKit.LAPACK_HouseholderQR(; kwargs...)
9+
return BlockPermutedDiagonalAlgorithm(alg)
10+
end
11+
12+
function MatrixAlgebraKit.initialize_output(
13+
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
14+
)
15+
bm, bn = blocksize(A)
16+
bmn = min(bm, bn)
17+
18+
brows = blocklengths(axes(A, 1))
19+
bcols = blocklengths(axes(A, 2))
20+
rlengths = Vector{Int}(undef, bmn)
21+
22+
# fill in values for blocks that are present
23+
bIs = collect(eachblockstoredindex(A))
24+
browIs = Int.(first.(Tuple.(bIs)))
25+
bcolIs = Int.(last.(Tuple.(bIs)))
26+
for bI in eachblockstoredindex(A)
27+
row, col = Int.(Tuple(bI))
28+
nrows = brows[row]
29+
ncols = bcols[col]
30+
rlengths[col] = min(nrows, ncols)
31+
end
32+
33+
# fill in values for blocks that aren't present, pairing them in order of occurence
34+
# this is a convention, which at least gives the expected results for blockdiagonal
35+
emptyrows = setdiff(1:bm, browIs)
36+
emptycols = setdiff(1:bn, bcolIs)
37+
for (row, col) in zip(emptyrows, emptycols)
38+
rlengths[col] = min(brows[row], bcols[col])
39+
end
40+
41+
r_axis = blockedrange(rlengths)
42+
Q = similar(A, axes(A, 1), r_axis)
43+
R = similar(A, r_axis, axes(A, 2))
44+
45+
# allocate output
46+
for bI in eachblockstoredindex(A)
47+
brow, bcol = Tuple(bI)
48+
Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output(
49+
qr_compact!, @view!(A[bI]), alg.alg
50+
)
51+
end
52+
53+
# allocate output for blocks that aren't present -- do we also fill identities here?
54+
for (row, col) in zip(emptyrows, emptycols)
55+
@view!(Q[Block(row, col)])
56+
end
57+
58+
return Q, R
59+
end
60+
61+
function MatrixAlgebraKit.initialize_output(
62+
::typeof(qr_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
63+
)
64+
bm, bn = blocksize(A)
65+
66+
brows = blocklengths(axes(A, 1))
67+
rlengths = copy(brows)
68+
69+
# fill in values for blocks that are present
70+
bIs = collect(eachblockstoredindex(A))
71+
browIs = Int.(first.(Tuple.(bIs)))
72+
bcolIs = Int.(last.(Tuple.(bIs)))
73+
for bI in eachblockstoredindex(A)
74+
row, col = Int.(Tuple(bI))
75+
nrows = brows[row]
76+
rlengths[col] = nrows
77+
end
78+
79+
# fill in values for blocks that aren't present, pairing them in order of occurence
80+
# this is a convention, which at least gives the expected results for blockdiagonal
81+
emptyrows = setdiff(1:bm, browIs)
82+
emptycols = setdiff(1:bn, bcolIs)
83+
for (row, col) in zip(emptyrows, emptycols)
84+
rlengths[col] = brows[row]
85+
end
86+
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
87+
rlengths[bn + i] = brows[emptyrows[k]]
88+
end
89+
90+
r_axis = blockedrange(rlengths)
91+
Q = similar(A, axes(A, 1), r_axis)
92+
R = similar(A, r_axis, axes(A, 2))
93+
94+
# allocate output
95+
for bI in eachblockstoredindex(A)
96+
brow, bcol = Tuple(bI)
97+
Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output(
98+
qr_full!, @view!(A[bI]), alg.alg
99+
)
100+
end
101+
102+
# allocate output for blocks that aren't present -- do we also fill identities here?
103+
for (row, col) in zip(emptyrows, emptycols)
104+
@view!(Q[Block(row, col)])
105+
end
106+
# also handle extra rows/cols
107+
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
108+
@view!(Q[Block(emptyrows[k], bn + i)])
109+
end
110+
111+
return Q, R
112+
end
113+
114+
function MatrixAlgebraKit.check_input(
115+
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, QR
116+
)
117+
Q, R = QR
118+
@assert isa(Q, AbstractBlockSparseMatrix) &&
119+
isa(R, AbstractBlockSparseMatrix)
120+
@assert eltype(A) == eltype(Q) == eltype(R)
121+
@assert axes(A, 1) == axes(Q, 1) && axes(A, 2) == axes(R, 2)
122+
@assert axes(Q, 2) == axes(R, 1)
123+
124+
return nothing
125+
end
126+
127+
function MatrixAlgebraKit.check_input(
128+
::typeof(qr_full!), A::AbstractBlockSparseMatrix, QR
129+
)
130+
Q, R = QR
131+
@assert isa(Q, AbstractBlockSparseMatrix) &&
132+
isa(R, AbstractBlockSparseMatrix)
133+
@assert eltype(A) == eltype(Q) == eltype(R)
134+
@assert axes(A, 1) == axes(Q, 1) && axes(A, 2) == axes(R, 2)
135+
@assert axes(Q, 2) == axes(R, 1)
136+
137+
return nothing
138+
end
139+
140+
function MatrixAlgebraKit.qr_compact!(
141+
A::AbstractBlockSparseMatrix, QR, alg::BlockPermutedDiagonalAlgorithm
142+
)
143+
MatrixAlgebraKit.check_input(qr_compact!, A, QR)
144+
Q, R = QR
145+
146+
# do decomposition on each block
147+
for bI in eachblockstoredindex(A)
148+
brow, bcol = Tuple(bI)
149+
qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol]))
150+
qr′ = qr_compact!(@view!(A[bI]), qr, alg.alg)
151+
@assert qr === qr′ "qr_compact! might not be in-place"
152+
end
153+
154+
# fill in identities for blocks that aren't present
155+
bIs = collect(eachblockstoredindex(A))
156+
browIs = Int.(first.(Tuple.(bIs)))
157+
bcolIs = Int.(last.(Tuple.(bIs)))
158+
emptyrows = setdiff(1:blocksize(A, 1), browIs)
159+
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
160+
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
161+
# Q[Block(row, col)] = LinearAlgebra.I
162+
for (row, col) in zip(emptyrows, emptycols)
163+
copyto!(@view!(Q[Block(row, col)]), LinearAlgebra.I)
164+
end
165+
166+
return QR
167+
end
168+
169+
function MatrixAlgebraKit.qr_full!(
170+
A::AbstractBlockSparseMatrix, QR, alg::BlockPermutedDiagonalAlgorithm
171+
)
172+
MatrixAlgebraKit.check_input(qr_full!, A, QR)
173+
Q, R = QR
174+
175+
# do decomposition on each block
176+
for bI in eachblockstoredindex(A)
177+
brow, bcol = Tuple(bI)
178+
qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol]))
179+
qr′ = qr_full!(@view!(A[bI]), qr, alg.alg)
180+
@assert qr === qr′ "qr_full! might not be in-place"
181+
end
182+
183+
# fill in identities for blocks that aren't present
184+
bIs = collect(eachblockstoredindex(A))
185+
browIs = Int.(first.(Tuple.(bIs)))
186+
bcolIs = Int.(last.(Tuple.(bIs)))
187+
emptyrows = setdiff(1:blocksize(A, 1), browIs)
188+
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
189+
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
190+
# Q[Block(row, col)] = LinearAlgebra.I
191+
for (row, col) in zip(emptyrows, emptycols)
192+
copyto!(@view!(Q[Block(row, col)]), LinearAlgebra.I)
193+
end
194+
195+
# also handle extra rows/cols
196+
bn = blocksize(A, 2)
197+
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
198+
copyto!(@view!(Q[Block(emptyrows[k], bn + i)]), LinearAlgebra.I)
199+
end
200+
201+
return QR
202+
end

0 commit comments

Comments
 (0)