Skip to content

Commit d926f52

Browse files
committed
Refactor LQ
1 parent cecf967 commit d926f52

File tree

1 file changed

+87
-158
lines changed

1 file changed

+87
-158
lines changed

src/factorizations/lq.jl

Lines changed: 87 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -8,211 +8,140 @@ function MatrixAlgebraKit.default_lq_algorithm(
88
end
99
end
1010

11-
function similar_output(
12-
::typeof(lq_compact!), A, L_axis, alg::MatrixAlgebraKit.AbstractAlgorithm
13-
)
14-
L = similar(A, axes(A, 1), L_axis)
15-
Q = similar(A, L_axis, axes(A, 2))
16-
return L, Q
11+
function output_type(
12+
f::Union{typeof(lq_compact!),typeof(lq_full!)}, A::Type{<:AbstractMatrix{T}}
13+
) where {T}
14+
LQ = Base.promote_op(f, A)
15+
return isconcretetype(LQ) ? LQ : Tuple{AbstractMatrix{T},AbstractMatrix{T}}
1716
end
1817

19-
function similar_output(
20-
::typeof(lq_full!), A, L_axis, alg::MatrixAlgebraKit.AbstractAlgorithm
18+
function MatrixAlgebraKit.initialize_output(
19+
::typeof(lq_compact!), ::AbstractBlockSparseMatrix, ::BlockPermutedDiagonalAlgorithm
2120
)
22-
L = similar(A, axes(A, 1), L_axis)
23-
Q = similar(A, L_axis, axes(A, 2))
24-
return L, Q
21+
return nothing
2522
end
26-
2723
function MatrixAlgebraKit.initialize_output(
28-
::typeof(lq_compact!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
24+
::typeof(lq_compact!), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
2925
)
30-
bm, bn = blocksize(A)
31-
bmn = min(bm, bn)
32-
3326
brows = eachblockaxis(axes(A, 1))
3427
bcols = eachblockaxis(axes(A, 2))
35-
l_axes = similar(brows, bmn)
36-
37-
# fill in values for blocks that are present
38-
bIs = collect(eachblockstoredindex(A))
39-
browIs = Int.(first.(Tuple.(bIs)))
40-
bcolIs = Int.(last.(Tuple.(bIs)))
41-
for bI in eachblockstoredindex(A)
42-
row, col = Int.(Tuple(bI))
43-
len = minimum(length, (brows[row], bcols[col]))
44-
l_axes[row] = bcols[col][Base.OneTo(len)]
45-
end
46-
47-
# fill in values for blocks that aren't present, pairing them in order of occurence
48-
# this is a convention, which at least gives the expected results for blockdiagonal
49-
emptyrows = setdiff(1:bm, browIs)
50-
emptycols = setdiff(1:bn, bcolIs)
51-
for (row, col) in zip(emptyrows, emptycols)
52-
len = minimum(length, (brows[row], bcols[col]))
53-
l_axes[row] = bcols[col][Base.OneTo(len)]
54-
end
55-
28+
# using the property that zip stops as soon as one of the iterators is exhausted
29+
l_axes = map(splat(infimum), zip(brows, bcols))
5630
l_axis = mortar_axis(l_axes)
57-
L, Q = similar_output(lq_compact!, A, l_axis, alg)
5831

59-
# allocate output
60-
for bI in eachblockstoredindex(A)
61-
brow, bcol = Tuple(bI)
62-
block = @view!(A[bI])
63-
block_alg = block_algorithm(alg, block)
64-
L[brow, brow], Q[brow, bcol] = MatrixAlgebraKit.initialize_output(
65-
lq_compact!, block, block_alg
66-
)
67-
end
68-
69-
# allocate output for blocks that aren't present -- do we also fill identities here?
70-
for (row, col) in zip(emptyrows, emptycols)
71-
@view!(Q[Block(row, col)])
72-
end
32+
BL, BQ = fieldtypes(output_type(lq_compact!, blocktype(A)))
33+
L = similar(A, BlockType(BL), (axes(A, 1), l_axis))
34+
Q = similar(A, BlockType(BQ), (l_axis, axes(A, 2)))
7335

7436
return L, Q
7537
end
7638

7739
function MatrixAlgebraKit.initialize_output(
78-
::typeof(lq_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
40+
::typeof(lq_full!), ::AbstractBlockSparseMatrix, ::BlockPermutedDiagonalAlgorithm
7941
)
80-
bm, bn = blocksize(A)
81-
82-
bcols = eachblockaxis(axes(A, 2))
83-
l_axes = copy(bcols)
84-
85-
# fill in values for blocks that are present
86-
bIs = collect(eachblockstoredindex(A))
87-
browIs = Int.(first.(Tuple.(bIs)))
88-
bcolIs = Int.(last.(Tuple.(bIs)))
89-
for bI in eachblockstoredindex(A)
90-
row, col = Int.(Tuple(bI))
91-
l_axes[row] = bcols[col]
92-
end
93-
94-
# fill in values for blocks that aren't present, pairing them in order of occurence
95-
# this is a convention, which at least gives the expected results for blockdiagonal
96-
emptyrows = setdiff(1:bm, browIs)
97-
emptycols = setdiff(1:bn, bcolIs)
98-
for (row, col) in zip(emptyrows, emptycols)
99-
l_axes[row] = bcols[col]
100-
end
101-
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
102-
l_axes[bn + i] = bcols[emptycols[k]]
103-
end
104-
105-
l_axis = mortar_axis(l_axes)
106-
L, Q = similar_output(lq_full!, A, l_axis, alg)
107-
108-
# allocate output
109-
for bI in eachblockstoredindex(A)
110-
brow, bcol = Tuple(bI)
111-
block = @view!(A[bI])
112-
block_alg = block_algorithm(alg, block)
113-
L[brow, brow], Q[brow, bcol] = MatrixAlgebraKit.initialize_output(
114-
lq_full!, block, block_alg
115-
)
116-
end
117-
118-
# allocate output for blocks that aren't present -- do we also fill identities here?
119-
for (row, col) in zip(emptyrows, emptycols)
120-
@view!(Q[Block(row, col)])
121-
end
122-
# also handle extra rows/cols
123-
for (i, k) in enumerate((length(emptyrows) + 1):length(emptycols))
124-
@view!(Q[Block(bm + i, emptycols[k])])
125-
end
126-
42+
return nothing
43+
end
44+
function MatrixAlgebraKit.initialize_output(
45+
::typeof(lq_full!), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
46+
)
47+
BL, BQ = fieldtypes(output_type(lq_full!, blocktype(A)))
48+
L = similar(A, BlockType(BL), (axes(A, 1), axes(A, 2)))
49+
Q = similar(A, BlockType(BQ), (axes(A, 2), axes(A, 2)))
12750
return L, Q
12851
end
12952

13053
function MatrixAlgebraKit.check_input(
131-
::typeof(lq_compact!), A::AbstractBlockSparseMatrix, LQ
54+
::typeof(lq_compact!), A::AbstractBlockSparseMatrix, LQ, ::BlockPermutedDiagonalAlgorithm
55+
)
56+
@assert isblockpermuteddiagonal(A)
57+
return nothing
58+
end
59+
function MatrixAlgebraKit.check_input(
60+
::typeof(lq_compact!), A::AbstractBlockSparseMatrix, (L, Q), ::BlockDiagonalAlgorithm
13261
)
133-
L, Q = LQ
13462
@assert isa(L, AbstractBlockSparseMatrix) && isa(Q, AbstractBlockSparseMatrix)
13563
@assert eltype(A) == eltype(L) == eltype(Q)
13664
@assert axes(A, 1) == axes(L, 1) && axes(A, 2) == axes(Q, 2)
13765
@assert axes(L, 2) == axes(Q, 1)
138-
66+
@assert isblockdiagonal(A)
13967
return nothing
14068
end
141-
142-
function MatrixAlgebraKit.check_input(::typeof(lq_full!), A::AbstractBlockSparseMatrix, LQ)
143-
L, Q = LQ
69+
function MatrixAlgebraKit.check_input(
70+
::typeof(lq_full!), A::AbstractBlockSparseMatrix, LQ, ::BlockPermutedDiagonalAlgorithm
71+
)
72+
@assert isblockpermuteddiagonal(A)
73+
return nothing
74+
end
75+
function MatrixAlgebraKit.check_input(
76+
::typeof(lq_full!), A::AbstractBlockSparseMatrix, (L, Q), ::BlockDiagonalAlgorithm
77+
)
14478
@assert isa(L, AbstractBlockSparseMatrix) && isa(Q, AbstractBlockSparseMatrix)
14579
@assert eltype(A) == eltype(L) == eltype(Q)
14680
@assert axes(A, 1) == axes(L, 1) && axes(A, 2) == axes(Q, 2)
14781
@assert axes(L, 2) == axes(Q, 1)
148-
82+
@assert isblockdiagonal(A)
14983
return nothing
15084
end
15185

15286
function MatrixAlgebraKit.lq_compact!(
15387
A::AbstractBlockSparseMatrix, LQ, alg::BlockPermutedDiagonalAlgorithm
15488
)
155-
MatrixAlgebraKit.check_input(lq_compact!, A, LQ)
156-
L, Q = LQ
89+
MatrixAlgebraKit.check_input(lq_compact!, A, LQ, alg)
90+
Ad, transform_rows, transform_cols = blockdiagonalize(A)
91+
Ld, Qd = lq_compact!(Ad, BlockDiagonalAlgorithm(alg))
92+
L = transform_rows(Ld)
93+
Q = transform_cols(Qd)
94+
return L, Q
95+
end
96+
function MatrixAlgebraKit.lq_compact!(
97+
A::AbstractBlockSparseMatrix, (L, Q), alg::BlockDiagonalAlgorithm
98+
)
99+
MatrixAlgebraKit.check_input(lq_compact!, A, (L, Q), alg)
157100

158101
# do decomposition on each block
159-
for bI in eachblockstoredindex(A)
160-
brow, bcol = Tuple(bI)
161-
lq = (@view!(L[brow, brow]), @view!(Q[brow, bcol]))
162-
block = @view!(A[bI])
163-
block_alg = block_algorithm(alg, block)
164-
lq′ = lq_compact!(block, lq, block_alg)
165-
@assert lq === lq′ "lq_compact! might not be in-place"
166-
end
167-
168-
# fill in identities for blocks that aren't present
169-
bIs = collect(eachblockstoredindex(A))
170-
browIs = Int.(first.(Tuple.(bIs)))
171-
bcolIs = Int.(last.(Tuple.(bIs)))
172-
emptyrows = setdiff(1:blocksize(A, 1), browIs)
173-
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
174-
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
175-
# Q[Block(row, col)] = LinearAlgebra.I
176-
for (row, col) in zip(emptyrows, emptycols)
177-
copyto!(@view!(Q[Block(row, col)]), LinearAlgebra.I)
102+
for I in 1:min(blocksize(A)...)
103+
bI = Block(I, I)
104+
if isstored(blocks(A), CartesianIndex(I, I)) # TODO: isblockstored
105+
block = @view!(A[bI])
106+
block_alg = block_algorithm(alg, block)
107+
bL, bQ = lq_compact!(block, block_alg)
108+
L[bI] = bL
109+
Q[bI] = bQ
110+
else
111+
copyto!(@view!(Q[bI]), LinearAlgebra.I)
112+
end
178113
end
179114

180-
return LQ
115+
return L, Q
181116
end
182117

183118
function MatrixAlgebraKit.lq_full!(
184119
A::AbstractBlockSparseMatrix, LQ, alg::BlockPermutedDiagonalAlgorithm
185120
)
186-
MatrixAlgebraKit.check_input(lq_full!, A, LQ)
187-
L, Q = LQ
188-
189-
# do decomposition on each block
190-
for bI in eachblockstoredindex(A)
191-
brow, bcol = Tuple(bI)
192-
lq = (@view!(L[brow, brow]), @view!(Q[brow, bcol]))
193-
block = @view!(A[bI])
194-
block_alg = block_algorithm(alg, block)
195-
lq′ = lq_full!(block, lq, block_alg)
196-
@assert lq === lq′ "lq_full! might not be in-place"
197-
end
198-
199-
# fill in identities for blocks that aren't present
200-
bIs = collect(eachblockstoredindex(A))
201-
browIs = Int.(first.(Tuple.(bIs)))
202-
bcolIs = Int.(last.(Tuple.(bIs)))
203-
emptyrows = setdiff(1:blocksize(A, 1), browIs)
204-
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
205-
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
206-
# Q[Block(row, col)] = LinearAlgebra.I
207-
for (row, col) in zip(emptyrows, emptycols)
208-
copyto!(@view!(Q[Block(row, col)]), LinearAlgebra.I)
209-
end
121+
MatrixAlgebraKit.check_input(lq_full!, A, LQ, alg)
122+
Ad, transform_rows, transform_cols = blockdiagonalize(A)
123+
Ld, Qd = lq_full!(Ad, BlockDiagonalAlgorithm(alg))
124+
L = transform_rows(Ld)
125+
Q = transform_cols(Qd)
126+
return L, Q
127+
end
128+
function MatrixAlgebraKit.lq_full!(
129+
A::AbstractBlockSparseMatrix, (L, Q), alg::BlockDiagonalAlgorithm
130+
)
131+
MatrixAlgebraKit.check_input(lq_full!, A, (L, Q), alg)
210132

211-
# also handle extra rows/cols
212-
bm = blocksize(A, 1)
213-
for (i, k) in enumerate((length(emptyrows) + 1):length(emptycols))
214-
copyto!(@view!(Q[Block(bm + i, emptycols[k])]), LinearAlgebra.I)
133+
for I in 1:min(blocksize(A)...)
134+
bI = Block(I, I)
135+
if isstored(blocks(A), CartesianIndex(I, I)) # TODO: isblockstored
136+
block = @view!(A[bI])
137+
block_alg = block_algorithm(alg, block)
138+
bL, bQ = lq_full!(block, block_alg)
139+
L[bI] = bL
140+
Q[bI] = bQ
141+
else
142+
copyto!(@view!(Q[bI]), LinearAlgebra.I)
143+
end
215144
end
216145

217-
return LQ
146+
return L, Q
218147
end

0 commit comments

Comments
 (0)