@@ -9,211 +9,143 @@ function MatrixAlgebraKit.default_qr_algorithm(
9
9
end
10
10
end
11
11
12
- function similar_output (
13
- :: typeof (qr_compact!), A, R_axis, alg:: MatrixAlgebraKit.AbstractAlgorithm
14
- )
15
- Q = similar (A, axes (A, 1 ), R_axis)
16
- R = similar (A, R_axis, axes (A, 2 ))
17
- return Q, R
12
+ function output_type (
13
+ f:: Union{typeof(qr_compact!),typeof(qr_full!)} , A:: Type{<:AbstractMatrix{T}}
14
+ ) where {T}
15
+ QR = Base. promote_op (f, A)
16
+ return isconcretetype (QR) ? QR : Tuple{AbstractMatrix{T},AbstractMatrix{T}}
18
17
end
19
18
20
- function similar_output (
21
- :: typeof (qr_full !), A, R_axis, alg :: MatrixAlgebraKit.AbstractAlgorithm
19
+ function MatrixAlgebraKit . initialize_output (
20
+ :: typeof (qr_compact !), :: AbstractBlockSparseMatrix , :: BlockPermutedDiagonalAlgorithm
22
21
)
23
- Q = similar (A, axes (A, 1 ), R_axis)
24
- R = similar (A, R_axis, axes (A, 2 ))
25
- return Q, R
22
+ return nothing
26
23
end
27
-
28
24
function MatrixAlgebraKit. initialize_output (
29
- :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , alg:: BlockPermutedDiagonalAlgorithm
25
+ :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
30
26
)
31
- bm, bn = blocksize (A)
32
- bmn = min (bm, bn)
33
-
34
27
brows = eachblockaxis (axes (A, 1 ))
35
28
bcols = eachblockaxis (axes (A, 2 ))
36
- r_axes = similar (brows, bmn)
37
-
38
- # fill in values for blocks that are present
39
- bIs = collect (eachblockstoredindex (A))
40
- browIs = Int .(first .(Tuple .(bIs)))
41
- bcolIs = Int .(last .(Tuple .(bIs)))
42
- for bI in eachblockstoredindex (A)
43
- row, col = Int .(Tuple (bI))
44
- len = minimum (length, (brows[row], bcols[col]))
45
- r_axes[col] = brows[row][Base. OneTo (len)]
46
- end
47
-
48
- # fill in values for blocks that aren't present, pairing them in order of occurence
49
- # this is a convention, which at least gives the expected results for blockdiagonal
50
- emptyrows = setdiff (1 : bm, browIs)
51
- emptycols = setdiff (1 : bn, bcolIs)
52
- for (row, col) in zip (emptyrows, emptycols)
53
- len = minimum (length, (brows[row], bcols[col]))
54
- r_axes[col] = brows[row][Base. OneTo (len)]
55
- end
56
-
29
+ # using the property that zip stops as soon as one of the iterators is exhausted
30
+ r_axes = map (splat (infimum), zip (brows, bcols))
57
31
r_axis = mortar_axis (r_axes)
58
- Q, R = similar_output (qr_compact!, A, r_axis, alg)
59
-
60
- # allocate output
61
- for bI in eachblockstoredindex (A)
62
- brow, bcol = Tuple (bI)
63
- block = @view! (A[bI])
64
- block_alg = block_algorithm (alg, block)
65
- Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit. initialize_output (
66
- qr_compact!, block, block_alg
67
- )
68
- end
69
32
70
- # allocate output for blocks that aren't present -- do we also fill identities here?
71
- for (row, col) in zip (emptyrows, emptycols)
72
- @view! (Q[Block (row, col)])
73
- end
33
+ BQ, BR = fieldtypes (output_type (qr_compact!, blocktype (A)))
34
+ Q = similar (A, BlockType (BQ), (axes (A, 1 ), r_axis))
35
+ R = similar (A, BlockType (BR), (r_axis, axes (A, 2 )))
74
36
75
37
return Q, R
76
38
end
77
39
78
40
function MatrixAlgebraKit. initialize_output (
79
- :: typeof (qr_full!), A :: AbstractBlockSparseMatrix , alg :: BlockPermutedDiagonalAlgorithm
41
+ :: typeof (qr_full!), :: AbstractBlockSparseMatrix , :: BlockPermutedDiagonalAlgorithm
80
42
)
81
- bm, bn = blocksize (A)
82
-
83
- brows = eachblockaxis (axes (A, 1 ))
84
- r_axes = copy (brows)
85
-
86
- # fill in values for blocks that are present
87
- bIs = collect (eachblockstoredindex (A))
88
- browIs = Int .(first .(Tuple .(bIs)))
89
- bcolIs = Int .(last .(Tuple .(bIs)))
90
- for bI in eachblockstoredindex (A)
91
- row, col = Int .(Tuple (bI))
92
- r_axes[col] = brows[row]
93
- end
94
-
95
- # fill in values for blocks that aren't present, pairing them in order of occurence
96
- # this is a convention, which at least gives the expected results for blockdiagonal
97
- emptyrows = setdiff (1 : bm, browIs)
98
- emptycols = setdiff (1 : bn, bcolIs)
99
- for (row, col) in zip (emptyrows, emptycols)
100
- r_axes[col] = brows[row]
101
- end
102
- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
103
- r_axes[bn + i] = brows[emptyrows[k]]
104
- end
105
-
106
- r_axis = mortar_axis (r_axes)
107
- Q, R = similar_output (qr_full!, A, r_axis, alg)
108
-
109
- # allocate output
110
- for bI in eachblockstoredindex (A)
111
- brow, bcol = Tuple (bI)
112
- block = @view! (A[bI])
113
- block_alg = block_algorithm (alg, block)
114
- Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit. initialize_output (
115
- qr_full!, block, block_alg
116
- )
117
- end
118
-
119
- # allocate output for blocks that aren't present -- do we also fill identities here?
120
- for (row, col) in zip (emptyrows, emptycols)
121
- @view! (Q[Block (row, col)])
122
- end
123
- # also handle extra rows/cols
124
- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
125
- @view! (Q[Block (emptyrows[k], bn + i)])
126
- end
127
-
43
+ return nothing
44
+ end
45
+ function MatrixAlgebraKit. initialize_output (
46
+ :: typeof (qr_full!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
47
+ )
48
+ BQ, BR = fieldtypes (output_type (qr_compact!, blocktype (A)))
49
+ Q = similar (A, BlockType (BQ), (axes (A, 1 ), axes (A, 1 )))
50
+ R = similar (A, BlockType (BR), (axes (A, 1 ), axes (A, 2 )))
128
51
return Q, R
129
52
end
130
53
131
54
function MatrixAlgebraKit. check_input (
132
- :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , QR
55
+ :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , QR, :: BlockPermutedDiagonalAlgorithm
56
+ )
57
+ @assert isblockpermuteddiagonal (A)
58
+ return nothing
59
+ end
60
+ function MatrixAlgebraKit. check_input (
61
+ :: typeof (qr_compact!), A:: AbstractBlockSparseMatrix , (Q, R), :: BlockDiagonalAlgorithm
133
62
)
134
- Q, R = QR
135
63
@assert isa (Q, AbstractBlockSparseMatrix) && isa (R, AbstractBlockSparseMatrix)
136
64
@assert eltype (A) == eltype (Q) == eltype (R)
137
65
@assert axes (A, 1 ) == axes (Q, 1 ) && axes (A, 2 ) == axes (R, 2 )
138
66
@assert axes (Q, 2 ) == axes (R, 1 )
139
-
67
+ @assert isblockdiagonal (A)
140
68
return nothing
141
69
end
142
70
143
- function MatrixAlgebraKit. check_input (:: typeof (qr_full!), A:: AbstractBlockSparseMatrix , QR)
144
- Q, R = QR
71
+ function MatrixAlgebraKit. check_input (
72
+ :: typeof (qr_full!), A:: AbstractBlockSparseMatrix , QR, :: BlockPermutedDiagonalAlgorithm
73
+ )
74
+ @assert isblockpermuteddiagonal (A)
75
+ return nothing
76
+ end
77
+ function MatrixAlgebraKit. check_input (
78
+ :: typeof (qr_full!), A:: AbstractBlockSparseMatrix , (Q, R), :: BlockDiagonalAlgorithm
79
+ )
145
80
@assert isa (Q, AbstractBlockSparseMatrix) && isa (R, AbstractBlockSparseMatrix)
146
81
@assert eltype (A) == eltype (Q) == eltype (R)
147
82
@assert axes (A, 1 ) == axes (Q, 1 ) && axes (A, 2 ) == axes (R, 2 )
148
83
@assert axes (Q, 2 ) == axes (R, 1 )
149
-
84
+ @assert isblockdiagonal (A)
150
85
return nothing
151
86
end
152
87
153
88
function MatrixAlgebraKit. qr_compact! (
154
89
A:: AbstractBlockSparseMatrix , QR, alg:: BlockPermutedDiagonalAlgorithm
155
90
)
156
- MatrixAlgebraKit. check_input (qr_compact!, A, QR)
157
- Q, R = QR
91
+ check_input (qr_compact!, A, QR, alg)
92
+ Ad, transform_rows, transform_cols = blockdiagonalize (A)
93
+ Qd, Rd = qr_compact! (Ad, BlockDiagonalAlgorithm (alg))
94
+ Q = transform_rows (Qd)
95
+ R = transform_cols (Rd)
96
+ return Q, R
97
+ end
158
98
159
- # do decomposition on each block
160
- for bI in eachblockstoredindex (A)
161
- brow, bcol = Tuple (bI)
162
- qr = (@view! (Q[brow, bcol]), @view! (R[bcol, bcol]))
163
- block = @view! (A[bI])
164
- block_alg = block_algorithm (alg, block)
165
- qr′ = qr_compact! (block, qr, block_alg)
166
- @assert qr === qr′ " qr_compact! might not be in-place"
167
- end
99
+ function MatrixAlgebraKit. qr_compact! (
100
+ A:: AbstractBlockSparseMatrix , (Q, R), alg:: BlockDiagonalAlgorithm
101
+ )
102
+ MatrixAlgebraKit. check_input (qr_compact!, A, (Q, R), alg)
168
103
169
- # fill in identities for blocks that aren't present
170
- bIs = collect (eachblockstoredindex (A))
171
- browIs = Int .(first .(Tuple .(bIs)))
172
- bcolIs = Int .(last .(Tuple .(bIs)))
173
- emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
174
- emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
175
- # needs copyto! instead because size(::LinearAlgebra.I) doesn't work
176
- # Q[Block(row, col)] = LinearAlgebra.I
177
- for (row, col) in zip (emptyrows, emptycols)
178
- copyto! (@view! (Q[Block (row, col)]), LinearAlgebra. I)
104
+ # do decomposition on each block
105
+ for I in 1 : min (blocksize (A)... )
106
+ bI = Block (I, I)
107
+ if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
108
+ block = @view! (A[bI])
109
+ block_alg = block_algorithm (alg, block)
110
+ bQ, bR = qr_compact! (block, block_alg)
111
+ Q[bI] = bQ
112
+ R[bI] = bR
113
+ else
114
+ copyto! (@view! (Q[bI]), LinearAlgebra. I)
115
+ end
179
116
end
180
117
181
- return QR
118
+ return Q, R
182
119
end
183
120
184
121
function MatrixAlgebraKit. qr_full! (
185
122
A:: AbstractBlockSparseMatrix , QR, alg:: BlockPermutedDiagonalAlgorithm
186
123
)
187
- MatrixAlgebraKit. check_input (qr_full!, A, QR)
188
- Q, R = QR
189
-
190
- # do decomposition on each block
191
- for bI in eachblockstoredindex (A)
192
- brow, bcol = Tuple (bI)
193
- qr = (@view! (Q[brow, bcol]), @view! (R[bcol, bcol]))
194
- block = @view! (A[bI])
195
- block_alg = block_algorithm (alg, block)
196
- qr′ = qr_full! (block, qr, block_alg)
197
- @assert qr === qr′ " qr_full! might not be in-place"
198
- end
124
+ check_input (qr_full!, A, QR, alg)
125
+ Ad, transform_rows, transform_cols = blockdiagonalize (A)
126
+ Qd, Rd = qr_full! (Ad, BlockDiagonalAlgorithm (alg))
127
+ Q = transform_rows (Qd)
128
+ R = transform_cols (Rd)
129
+ return Q, R
130
+ end
199
131
200
- # fill in identities for blocks that aren't present
201
- bIs = collect (eachblockstoredindex (A))
202
- browIs = Int .(first .(Tuple .(bIs)))
203
- bcolIs = Int .(last .(Tuple .(bIs)))
204
- emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
205
- emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
206
- # needs copyto! instead because size(::LinearAlgebra.I) doesn't work
207
- # Q[Block(row, col)] = LinearAlgebra.I
208
- for (row, col) in zip (emptyrows, emptycols)
209
- copyto! (@view! (Q[Block (row, col)]), LinearAlgebra. I)
210
- end
132
+ function MatrixAlgebraKit. qr_full! (
133
+ A:: AbstractBlockSparseMatrix , (Q, R), alg:: BlockDiagonalAlgorithm
134
+ )
135
+ MatrixAlgebraKit. check_input (qr_full!, A, (Q, R), alg)
211
136
212
- # also handle extra rows/cols
213
- bn = blocksize (A, 2 )
214
- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
215
- copyto! (@view! (Q[Block (emptyrows[k], bn + i)]), LinearAlgebra. I)
137
+ for I in 1 : min (blocksize (A)... )
138
+ bI = Block (I, I)
139
+ if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
140
+ block = @view! (A[bI])
141
+ block_alg = block_algorithm (alg, block)
142
+ bQ, bR = qr_full! (block, block_alg)
143
+ Q[bI] = bQ
144
+ R[bI] = bR
145
+ else
146
+ copyto! (@view! (Q[bI]), LinearAlgebra. I)
147
+ end
216
148
end
217
149
218
- return QR
150
+ return Q, R
219
151
end
0 commit comments