Skip to content

Commit d63cc94

Browse files
committed
Refactor QR
1 parent 2fa9476 commit d63cc94

File tree

1 file changed

+88
-156
lines changed

1 file changed

+88
-156
lines changed

src/factorizations/qr.jl

Lines changed: 88 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -9,211 +9,143 @@ function MatrixAlgebraKit.default_qr_algorithm(
99
end
1010
end
1111

12-
function similar_output(
13-
::typeof(qr_compact!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm
14-
)
15-
Q = similar(A, axes(A, 1), R_axis)
16-
R = similar(A, R_axis, axes(A, 2))
17-
return Q, R
12+
function output_type(
13+
f::Union{typeof(qr_compact!),typeof(qr_full!)}, A::Type{<:AbstractMatrix{T}}
14+
) where {T}
15+
QR = Base.promote_op(f, A)
16+
return isconcretetype(QR) ? QR : Tuple{AbstractMatrix{T},AbstractMatrix{T}}
1817
end
1918

20-
function similar_output(
21-
::typeof(qr_full!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm
19+
function MatrixAlgebraKit.initialize_output(
20+
::typeof(qr_compact!), ::AbstractBlockSparseMatrix, ::BlockPermutedDiagonalAlgorithm
2221
)
23-
Q = similar(A, axes(A, 1), R_axis)
24-
R = similar(A, R_axis, axes(A, 2))
25-
return Q, R
22+
return nothing
2623
end
27-
2824
function MatrixAlgebraKit.initialize_output(
29-
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
25+
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
3026
)
31-
bm, bn = blocksize(A)
32-
bmn = min(bm, bn)
33-
3427
brows = eachblockaxis(axes(A, 1))
3528
bcols = eachblockaxis(axes(A, 2))
36-
r_axes = similar(brows, bmn)
37-
38-
# fill in values for blocks that are present
39-
bIs = collect(eachblockstoredindex(A))
40-
browIs = Int.(first.(Tuple.(bIs)))
41-
bcolIs = Int.(last.(Tuple.(bIs)))
42-
for bI in eachblockstoredindex(A)
43-
row, col = Int.(Tuple(bI))
44-
len = minimum(length, (brows[row], bcols[col]))
45-
r_axes[col] = brows[row][Base.OneTo(len)]
46-
end
47-
48-
# fill in values for blocks that aren't present, pairing them in order of occurence
49-
# this is a convention, which at least gives the expected results for blockdiagonal
50-
emptyrows = setdiff(1:bm, browIs)
51-
emptycols = setdiff(1:bn, bcolIs)
52-
for (row, col) in zip(emptyrows, emptycols)
53-
len = minimum(length, (brows[row], bcols[col]))
54-
r_axes[col] = brows[row][Base.OneTo(len)]
55-
end
56-
29+
# using the property that zip stops as soon as one of the iterators is exhausted
30+
r_axes = map(splat(infimum), zip(brows, bcols))
5731
r_axis = mortar_axis(r_axes)
58-
Q, R = similar_output(qr_compact!, A, r_axis, alg)
59-
60-
# allocate output
61-
for bI in eachblockstoredindex(A)
62-
brow, bcol = Tuple(bI)
63-
block = @view!(A[bI])
64-
block_alg = block_algorithm(alg, block)
65-
Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output(
66-
qr_compact!, block, block_alg
67-
)
68-
end
6932

70-
# allocate output for blocks that aren't present -- do we also fill identities here?
71-
for (row, col) in zip(emptyrows, emptycols)
72-
@view!(Q[Block(row, col)])
73-
end
33+
BQ, BR = fieldtypes(output_type(qr_compact!, blocktype(A)))
34+
Q = similar(A, BlockType(BQ), (axes(A, 1), r_axis))
35+
R = similar(A, BlockType(BR), (r_axis, axes(A, 2)))
7436

7537
return Q, R
7638
end
7739

7840
function MatrixAlgebraKit.initialize_output(
79-
::typeof(qr_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
41+
::typeof(qr_full!), ::AbstractBlockSparseMatrix, ::BlockPermutedDiagonalAlgorithm
8042
)
81-
bm, bn = blocksize(A)
82-
83-
brows = eachblockaxis(axes(A, 1))
84-
r_axes = copy(brows)
85-
86-
# fill in values for blocks that are present
87-
bIs = collect(eachblockstoredindex(A))
88-
browIs = Int.(first.(Tuple.(bIs)))
89-
bcolIs = Int.(last.(Tuple.(bIs)))
90-
for bI in eachblockstoredindex(A)
91-
row, col = Int.(Tuple(bI))
92-
r_axes[col] = brows[row]
93-
end
94-
95-
# fill in values for blocks that aren't present, pairing them in order of occurence
96-
# this is a convention, which at least gives the expected results for blockdiagonal
97-
emptyrows = setdiff(1:bm, browIs)
98-
emptycols = setdiff(1:bn, bcolIs)
99-
for (row, col) in zip(emptyrows, emptycols)
100-
r_axes[col] = brows[row]
101-
end
102-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
103-
r_axes[bn + i] = brows[emptyrows[k]]
104-
end
105-
106-
r_axis = mortar_axis(r_axes)
107-
Q, R = similar_output(qr_full!, A, r_axis, alg)
108-
109-
# allocate output
110-
for bI in eachblockstoredindex(A)
111-
brow, bcol = Tuple(bI)
112-
block = @view!(A[bI])
113-
block_alg = block_algorithm(alg, block)
114-
Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output(
115-
qr_full!, block, block_alg
116-
)
117-
end
118-
119-
# allocate output for blocks that aren't present -- do we also fill identities here?
120-
for (row, col) in zip(emptyrows, emptycols)
121-
@view!(Q[Block(row, col)])
122-
end
123-
# also handle extra rows/cols
124-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
125-
@view!(Q[Block(emptyrows[k], bn + i)])
126-
end
127-
43+
return nothing
44+
end
45+
function MatrixAlgebraKit.initialize_output(
46+
::typeof(qr_full!), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
47+
)
48+
BQ, BR = fieldtypes(output_type(qr_compact!, blocktype(A)))
49+
Q = similar(A, BlockType(BQ), (axes(A, 1), axes(A, 1)))
50+
R = similar(A, BlockType(BR), (axes(A, 1), axes(A, 2)))
12851
return Q, R
12952
end
13053

13154
function MatrixAlgebraKit.check_input(
132-
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, QR
55+
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, QR, ::BlockPermutedDiagonalAlgorithm
56+
)
57+
@assert isblockpermuteddiagonal(A)
58+
return nothing
59+
end
60+
function MatrixAlgebraKit.check_input(
61+
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, (Q, R), ::BlockDiagonalAlgorithm
13362
)
134-
Q, R = QR
13563
@assert isa(Q, AbstractBlockSparseMatrix) && isa(R, AbstractBlockSparseMatrix)
13664
@assert eltype(A) == eltype(Q) == eltype(R)
13765
@assert axes(A, 1) == axes(Q, 1) && axes(A, 2) == axes(R, 2)
13866
@assert axes(Q, 2) == axes(R, 1)
139-
67+
@assert isblockdiagonal(A)
14068
return nothing
14169
end
14270

143-
function MatrixAlgebraKit.check_input(::typeof(qr_full!), A::AbstractBlockSparseMatrix, QR)
144-
Q, R = QR
71+
function MatrixAlgebraKit.check_input(
72+
::typeof(qr_full!), A::AbstractBlockSparseMatrix, QR, ::BlockPermutedDiagonalAlgorithm
73+
)
74+
@assert isblockpermuteddiagonal(A)
75+
return nothing
76+
end
77+
function MatrixAlgebraKit.check_input(
78+
::typeof(qr_full!), A::AbstractBlockSparseMatrix, (Q, R), ::BlockDiagonalAlgorithm
79+
)
14580
@assert isa(Q, AbstractBlockSparseMatrix) && isa(R, AbstractBlockSparseMatrix)
14681
@assert eltype(A) == eltype(Q) == eltype(R)
14782
@assert axes(A, 1) == axes(Q, 1) && axes(A, 2) == axes(R, 2)
14883
@assert axes(Q, 2) == axes(R, 1)
149-
84+
@assert isblockdiagonal(A)
15085
return nothing
15186
end
15287

15388
function MatrixAlgebraKit.qr_compact!(
15489
A::AbstractBlockSparseMatrix, QR, alg::BlockPermutedDiagonalAlgorithm
15590
)
156-
MatrixAlgebraKit.check_input(qr_compact!, A, QR)
157-
Q, R = QR
91+
check_input(qr_compact!, A, QR, alg)
92+
Ad, transform_rows, transform_cols = blockdiagonalize(A)
93+
Qd, Rd = qr_compact!(Ad, BlockDiagonalAlgorithm(alg))
94+
Q = transform_rows(Qd)
95+
R = transform_cols(Rd)
96+
return Q, R
97+
end
15898

159-
# do decomposition on each block
160-
for bI in eachblockstoredindex(A)
161-
brow, bcol = Tuple(bI)
162-
qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol]))
163-
block = @view!(A[bI])
164-
block_alg = block_algorithm(alg, block)
165-
qr′ = qr_compact!(block, qr, block_alg)
166-
@assert qr === qr′ "qr_compact! might not be in-place"
167-
end
99+
function MatrixAlgebraKit.qr_compact!(
100+
A::AbstractBlockSparseMatrix, (Q, R), alg::BlockDiagonalAlgorithm
101+
)
102+
MatrixAlgebraKit.check_input(qr_compact!, A, (Q, R), alg)
168103

169-
# fill in identities for blocks that aren't present
170-
bIs = collect(eachblockstoredindex(A))
171-
browIs = Int.(first.(Tuple.(bIs)))
172-
bcolIs = Int.(last.(Tuple.(bIs)))
173-
emptyrows = setdiff(1:blocksize(A, 1), browIs)
174-
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
175-
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
176-
# Q[Block(row, col)] = LinearAlgebra.I
177-
for (row, col) in zip(emptyrows, emptycols)
178-
copyto!(@view!(Q[Block(row, col)]), LinearAlgebra.I)
104+
# do decomposition on each block
105+
for I in 1:min(blocksize(A)...)
106+
bI = Block(I, I)
107+
if isstored(blocks(A), CartesianIndex(I, I)) # TODO: isblockstored
108+
block = @view!(A[bI])
109+
block_alg = block_algorithm(alg, block)
110+
bQ, bR = qr_compact!(block, block_alg)
111+
Q[bI] = bQ
112+
R[bI] = bR
113+
else
114+
copyto!(@view!(Q[bI]), LinearAlgebra.I)
115+
end
179116
end
180117

181-
return QR
118+
return Q, R
182119
end
183120

184121
function MatrixAlgebraKit.qr_full!(
185122
A::AbstractBlockSparseMatrix, QR, alg::BlockPermutedDiagonalAlgorithm
186123
)
187-
MatrixAlgebraKit.check_input(qr_full!, A, QR)
188-
Q, R = QR
189-
190-
# do decomposition on each block
191-
for bI in eachblockstoredindex(A)
192-
brow, bcol = Tuple(bI)
193-
qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol]))
194-
block = @view!(A[bI])
195-
block_alg = block_algorithm(alg, block)
196-
qr′ = qr_full!(block, qr, block_alg)
197-
@assert qr === qr′ "qr_full! might not be in-place"
198-
end
124+
check_input(qr_full!, A, QR, alg)
125+
Ad, transform_rows, transform_cols = blockdiagonalize(A)
126+
Qd, Rd = qr_full!(Ad, BlockDiagonalAlgorithm(alg))
127+
Q = transform_rows(Qd)
128+
R = transform_cols(Rd)
129+
return Q, R
130+
end
199131

200-
# fill in identities for blocks that aren't present
201-
bIs = collect(eachblockstoredindex(A))
202-
browIs = Int.(first.(Tuple.(bIs)))
203-
bcolIs = Int.(last.(Tuple.(bIs)))
204-
emptyrows = setdiff(1:blocksize(A, 1), browIs)
205-
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
206-
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
207-
# Q[Block(row, col)] = LinearAlgebra.I
208-
for (row, col) in zip(emptyrows, emptycols)
209-
copyto!(@view!(Q[Block(row, col)]), LinearAlgebra.I)
210-
end
132+
function MatrixAlgebraKit.qr_full!(
133+
A::AbstractBlockSparseMatrix, (Q, R), alg::BlockDiagonalAlgorithm
134+
)
135+
MatrixAlgebraKit.check_input(qr_full!, A, (Q, R), alg)
211136

212-
# also handle extra rows/cols
213-
bn = blocksize(A, 2)
214-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
215-
copyto!(@view!(Q[Block(emptyrows[k], bn + i)]), LinearAlgebra.I)
137+
for I in 1:min(blocksize(A)...)
138+
bI = Block(I, I)
139+
if isstored(blocks(A), CartesianIndex(I, I)) # TODO: isblockstored
140+
block = @view!(A[bI])
141+
block_alg = block_algorithm(alg, block)
142+
bQ, bR = qr_full!(block, block_alg)
143+
Q[bI] = bQ
144+
R[bI] = bR
145+
else
146+
copyto!(@view!(Q[bI]), LinearAlgebra.I)
147+
end
216148
end
217149

218-
return QR
150+
return Q, R
219151
end

0 commit comments

Comments
 (0)