Skip to content

Commit cecf967

Browse files
committed
Refactor QR
1 parent 2fa9476 commit cecf967

File tree

1 file changed

+89
-158
lines changed

1 file changed

+89
-158
lines changed

src/factorizations/qr.jl

Lines changed: 89 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using MatrixAlgebraKit:
2-
MatrixAlgebraKit, default_qr_algorithm, lq_compact!, lq_full!, qr_compact!, qr_full!
1+
using MatrixAlgebraKit: MatrixAlgebraKit, default_qr_algorithm, qr_compact!, qr_full!
32

43
function MatrixAlgebraKit.default_qr_algorithm(
54
::Type{<:AbstractBlockSparseMatrix}; kwargs...
@@ -9,211 +8,143 @@ function MatrixAlgebraKit.default_qr_algorithm(
98
end
109
end
1110

12-
function similar_output(
13-
::typeof(qr_compact!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm
14-
)
15-
Q = similar(A, axes(A, 1), R_axis)
16-
R = similar(A, R_axis, axes(A, 2))
17-
return Q, R
11+
function output_type(
12+
f::Union{typeof(qr_compact!),typeof(qr_full!)}, A::Type{<:AbstractMatrix{T}}
13+
) where {T}
14+
QR = Base.promote_op(f, A)
15+
return isconcretetype(QR) ? QR : Tuple{AbstractMatrix{T},AbstractMatrix{T}}
1816
end
1917

20-
function similar_output(
21-
::typeof(qr_full!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm
18+
function MatrixAlgebraKit.initialize_output(
19+
::typeof(qr_compact!), ::AbstractBlockSparseMatrix, ::BlockPermutedDiagonalAlgorithm
2220
)
23-
Q = similar(A, axes(A, 1), R_axis)
24-
R = similar(A, R_axis, axes(A, 2))
25-
return Q, R
21+
return nothing
2622
end
27-
2823
function MatrixAlgebraKit.initialize_output(
29-
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
24+
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
3025
)
31-
bm, bn = blocksize(A)
32-
bmn = min(bm, bn)
33-
3426
brows = eachblockaxis(axes(A, 1))
3527
bcols = eachblockaxis(axes(A, 2))
36-
r_axes = similar(brows, bmn)
37-
38-
# fill in values for blocks that are present
39-
bIs = collect(eachblockstoredindex(A))
40-
browIs = Int.(first.(Tuple.(bIs)))
41-
bcolIs = Int.(last.(Tuple.(bIs)))
42-
for bI in eachblockstoredindex(A)
43-
row, col = Int.(Tuple(bI))
44-
len = minimum(length, (brows[row], bcols[col]))
45-
r_axes[col] = brows[row][Base.OneTo(len)]
46-
end
47-
48-
# fill in values for blocks that aren't present, pairing them in order of occurence
49-
# this is a convention, which at least gives the expected results for blockdiagonal
50-
emptyrows = setdiff(1:bm, browIs)
51-
emptycols = setdiff(1:bn, bcolIs)
52-
for (row, col) in zip(emptyrows, emptycols)
53-
len = minimum(length, (brows[row], bcols[col]))
54-
r_axes[col] = brows[row][Base.OneTo(len)]
55-
end
56-
28+
# using the property that zip stops as soon as one of the iterators is exhausted
29+
r_axes = map(splat(infimum), zip(brows, bcols))
5730
r_axis = mortar_axis(r_axes)
58-
Q, R = similar_output(qr_compact!, A, r_axis, alg)
59-
60-
# allocate output
61-
for bI in eachblockstoredindex(A)
62-
brow, bcol = Tuple(bI)
63-
block = @view!(A[bI])
64-
block_alg = block_algorithm(alg, block)
65-
Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output(
66-
qr_compact!, block, block_alg
67-
)
68-
end
6931

70-
# allocate output for blocks that aren't present -- do we also fill identities here?
71-
for (row, col) in zip(emptyrows, emptycols)
72-
@view!(Q[Block(row, col)])
73-
end
32+
BQ, BR = fieldtypes(output_type(qr_compact!, blocktype(A)))
33+
Q = similar(A, BlockType(BQ), (axes(A, 1), r_axis))
34+
R = similar(A, BlockType(BR), (r_axis, axes(A, 2)))
7435

7536
return Q, R
7637
end
7738

7839
function MatrixAlgebraKit.initialize_output(
79-
::typeof(qr_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
40+
::typeof(qr_full!), ::AbstractBlockSparseMatrix, ::BlockPermutedDiagonalAlgorithm
8041
)
81-
bm, bn = blocksize(A)
82-
83-
brows = eachblockaxis(axes(A, 1))
84-
r_axes = copy(brows)
85-
86-
# fill in values for blocks that are present
87-
bIs = collect(eachblockstoredindex(A))
88-
browIs = Int.(first.(Tuple.(bIs)))
89-
bcolIs = Int.(last.(Tuple.(bIs)))
90-
for bI in eachblockstoredindex(A)
91-
row, col = Int.(Tuple(bI))
92-
r_axes[col] = brows[row]
93-
end
94-
95-
# fill in values for blocks that aren't present, pairing them in order of occurence
96-
# this is a convention, which at least gives the expected results for blockdiagonal
97-
emptyrows = setdiff(1:bm, browIs)
98-
emptycols = setdiff(1:bn, bcolIs)
99-
for (row, col) in zip(emptyrows, emptycols)
100-
r_axes[col] = brows[row]
101-
end
102-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
103-
r_axes[bn + i] = brows[emptyrows[k]]
104-
end
105-
106-
r_axis = mortar_axis(r_axes)
107-
Q, R = similar_output(qr_full!, A, r_axis, alg)
108-
109-
# allocate output
110-
for bI in eachblockstoredindex(A)
111-
brow, bcol = Tuple(bI)
112-
block = @view!(A[bI])
113-
block_alg = block_algorithm(alg, block)
114-
Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output(
115-
qr_full!, block, block_alg
116-
)
117-
end
118-
119-
# allocate output for blocks that aren't present -- do we also fill identities here?
120-
for (row, col) in zip(emptyrows, emptycols)
121-
@view!(Q[Block(row, col)])
122-
end
123-
# also handle extra rows/cols
124-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
125-
@view!(Q[Block(emptyrows[k], bn + i)])
126-
end
127-
42+
return nothing
43+
end
44+
function MatrixAlgebraKit.initialize_output(
45+
::typeof(qr_full!), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
46+
)
47+
BQ, BR = fieldtypes(output_type(qr_full!, blocktype(A)))
48+
Q = similar(A, BlockType(BQ), (axes(A, 1), axes(A, 1)))
49+
R = similar(A, BlockType(BR), (axes(A, 1), axes(A, 2)))
12850
return Q, R
12951
end
13052

13153
function MatrixAlgebraKit.check_input(
132-
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, QR
54+
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, QR, ::BlockPermutedDiagonalAlgorithm
55+
)
56+
@assert isblockpermuteddiagonal(A)
57+
return nothing
58+
end
59+
function MatrixAlgebraKit.check_input(
60+
::typeof(qr_compact!), A::AbstractBlockSparseMatrix, (Q, R), ::BlockDiagonalAlgorithm
13361
)
134-
Q, R = QR
13562
@assert isa(Q, AbstractBlockSparseMatrix) && isa(R, AbstractBlockSparseMatrix)
13663
@assert eltype(A) == eltype(Q) == eltype(R)
13764
@assert axes(A, 1) == axes(Q, 1) && axes(A, 2) == axes(R, 2)
13865
@assert axes(Q, 2) == axes(R, 1)
139-
66+
@assert isblockdiagonal(A)
14067
return nothing
14168
end
14269

143-
function MatrixAlgebraKit.check_input(::typeof(qr_full!), A::AbstractBlockSparseMatrix, QR)
144-
Q, R = QR
70+
function MatrixAlgebraKit.check_input(
71+
::typeof(qr_full!), A::AbstractBlockSparseMatrix, QR, ::BlockPermutedDiagonalAlgorithm
72+
)
73+
@assert isblockpermuteddiagonal(A)
74+
return nothing
75+
end
76+
function MatrixAlgebraKit.check_input(
77+
::typeof(qr_full!), A::AbstractBlockSparseMatrix, (Q, R), ::BlockDiagonalAlgorithm
78+
)
14579
@assert isa(Q, AbstractBlockSparseMatrix) && isa(R, AbstractBlockSparseMatrix)
14680
@assert eltype(A) == eltype(Q) == eltype(R)
14781
@assert axes(A, 1) == axes(Q, 1) && axes(A, 2) == axes(R, 2)
14882
@assert axes(Q, 2) == axes(R, 1)
149-
83+
@assert isblockdiagonal(A)
15084
return nothing
15185
end
15286

15387
function MatrixAlgebraKit.qr_compact!(
15488
A::AbstractBlockSparseMatrix, QR, alg::BlockPermutedDiagonalAlgorithm
15589
)
156-
MatrixAlgebraKit.check_input(qr_compact!, A, QR)
157-
Q, R = QR
90+
check_input(qr_compact!, A, QR, alg)
91+
Ad, transform_rows, transform_cols = blockdiagonalize(A)
92+
Qd, Rd = qr_compact!(Ad, BlockDiagonalAlgorithm(alg))
93+
Q = transform_rows(Qd)
94+
R = transform_cols(Rd)
95+
return Q, R
96+
end
15897

159-
# do decomposition on each block
160-
for bI in eachblockstoredindex(A)
161-
brow, bcol = Tuple(bI)
162-
qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol]))
163-
block = @view!(A[bI])
164-
block_alg = block_algorithm(alg, block)
165-
qr′ = qr_compact!(block, qr, block_alg)
166-
@assert qr === qr′ "qr_compact! might not be in-place"
167-
end
98+
function MatrixAlgebraKit.qr_compact!(
99+
A::AbstractBlockSparseMatrix, (Q, R), alg::BlockDiagonalAlgorithm
100+
)
101+
MatrixAlgebraKit.check_input(qr_compact!, A, (Q, R), alg)
168102

169-
# fill in identities for blocks that aren't present
170-
bIs = collect(eachblockstoredindex(A))
171-
browIs = Int.(first.(Tuple.(bIs)))
172-
bcolIs = Int.(last.(Tuple.(bIs)))
173-
emptyrows = setdiff(1:blocksize(A, 1), browIs)
174-
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
175-
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
176-
# Q[Block(row, col)] = LinearAlgebra.I
177-
for (row, col) in zip(emptyrows, emptycols)
178-
copyto!(@view!(Q[Block(row, col)]), LinearAlgebra.I)
103+
# do decomposition on each block
104+
for I in 1:min(blocksize(A)...)
105+
bI = Block(I, I)
106+
if isstored(blocks(A), CartesianIndex(I, I)) # TODO: isblockstored
107+
block = @view!(A[bI])
108+
block_alg = block_algorithm(alg, block)
109+
bQ, bR = qr_compact!(block, block_alg)
110+
Q[bI] = bQ
111+
R[bI] = bR
112+
else
113+
copyto!(@view!(Q[bI]), LinearAlgebra.I)
114+
end
179115
end
180116

181-
return QR
117+
return Q, R
182118
end
183119

184120
function MatrixAlgebraKit.qr_full!(
185121
A::AbstractBlockSparseMatrix, QR, alg::BlockPermutedDiagonalAlgorithm
186122
)
187-
MatrixAlgebraKit.check_input(qr_full!, A, QR)
188-
Q, R = QR
189-
190-
# do decomposition on each block
191-
for bI in eachblockstoredindex(A)
192-
brow, bcol = Tuple(bI)
193-
qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol]))
194-
block = @view!(A[bI])
195-
block_alg = block_algorithm(alg, block)
196-
qr′ = qr_full!(block, qr, block_alg)
197-
@assert qr === qr′ "qr_full! might not be in-place"
198-
end
123+
check_input(qr_full!, A, QR, alg)
124+
Ad, transform_rows, transform_cols = blockdiagonalize(A)
125+
Qd, Rd = qr_full!(Ad, BlockDiagonalAlgorithm(alg))
126+
Q = transform_rows(Qd)
127+
R = transform_cols(Rd)
128+
return Q, R
129+
end
199130

200-
# fill in identities for blocks that aren't present
201-
bIs = collect(eachblockstoredindex(A))
202-
browIs = Int.(first.(Tuple.(bIs)))
203-
bcolIs = Int.(last.(Tuple.(bIs)))
204-
emptyrows = setdiff(1:blocksize(A, 1), browIs)
205-
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
206-
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
207-
# Q[Block(row, col)] = LinearAlgebra.I
208-
for (row, col) in zip(emptyrows, emptycols)
209-
copyto!(@view!(Q[Block(row, col)]), LinearAlgebra.I)
210-
end
131+
function MatrixAlgebraKit.qr_full!(
132+
A::AbstractBlockSparseMatrix, (Q, R), alg::BlockDiagonalAlgorithm
133+
)
134+
MatrixAlgebraKit.check_input(qr_full!, A, (Q, R), alg)
211135

212-
# also handle extra rows/cols
213-
bn = blocksize(A, 2)
214-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
215-
copyto!(@view!(Q[Block(emptyrows[k], bn + i)]), LinearAlgebra.I)
136+
for I in 1:min(blocksize(A)...)
137+
bI = Block(I, I)
138+
if isstored(blocks(A), CartesianIndex(I, I)) # TODO: isblockstored
139+
block = @view!(A[bI])
140+
block_alg = block_algorithm(alg, block)
141+
bQ, bR = qr_full!(block, block_alg)
142+
Q[bI] = bQ
143+
R[bI] = bR
144+
else
145+
copyto!(@view!(Q[bI]), LinearAlgebra.I)
146+
end
216147
end
217148

218-
return QR
149+
return Q, R
219150
end

0 commit comments

Comments
 (0)