@@ -112,55 +112,51 @@ function MatrixAlgebraKit.initialize_output(
112
112
bm, bn = blocksize (A)
113
113
114
114
brows = eachblockaxis (axes (A, 1 ))
115
- u_axes = similar (brows)
115
+ bcols = eachblockaxis (axes (A, 2 ))
116
+ u_axes = similar (brows, bm)
117
+ v_axes = similar (bcols, bn)
116
118
117
119
# fill in values for blocks that are present
118
- bIs = collect (eachblockstoredindex (A))
120
+ bIs = sort! ( collect (eachblockstoredindex (A)), by = Int ∘ last ∘ Tuple )
119
121
browIs = Int .(first .(Tuple .(bIs)))
120
122
bcolIs = Int .(last .(Tuple .(bIs)))
121
- for bI in eachblockstoredindex (A )
123
+ for (I, bI) in enumerate (bIs )
122
124
row, col = Int .(Tuple (bI))
123
- u_axes[col] = brows[row]
125
+ u_axes[I] = brows[row]
126
+ v_axes[I] = bcols[col]
124
127
end
125
128
126
129
# fill in values for blocks that aren't present, pairing them in order of occurence
127
130
# this is a convention, which at least gives the expected results for blockdiagonal
128
131
emptyrows = setdiff (1 : bm, browIs)
132
+ u_axes[length (bIs) .+ (1 : length (emptyrows))] .= brows[emptyrows]
129
133
emptycols = setdiff (1 : bn, bcolIs)
130
- for (row, col) in zip (emptyrows, emptycols)
131
- u_axes[col] = brows[row]
132
- end
133
- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
134
- u_axes[bn + i] = brows[emptyrows[k]]
135
- end
136
-
134
+ v_axes[length (bIs) .+ (1 : length (emptycols))] .= bcols[emptycols]
135
+
137
136
u_axis = mortar_axis (u_axes)
138
- S_axes = (u_axis, axes (A, 2 ))
137
+ v_axis = mortar_axis (@show v_axes)
138
+ S_axes = (u_axis, v_axis)
139
139
U, S, Vt = similar_output (svd_full!, A, S_axes, alg)
140
140
141
141
# allocate output
142
- for bI in eachblockstoredindex (A )
142
+ for (I, bI) in enumerate (bIs )
143
143
brow, bcol = Tuple (bI)
144
+ bcol′ = Block (I)
144
145
block = @view! (A[bI])
145
146
block_alg = block_algorithm (alg, block)
146
- U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit. initialize_output (
147
+ U[brow, bcol′ ], S[bcol′ , bcol′ ], Vt[bcol′ , bcol] = MatrixAlgebraKit. initialize_output (
147
148
svd_full!, block, block_alg
148
149
)
149
150
end
150
151
151
152
# allocate output for blocks that aren't present -- do we also fill identities here?
152
- for (row, col) in zip (emptyrows, emptycols)
153
- @view! (U[Block (row, col)])
154
- @view! (Vt[Block (col, col)])
153
+ for (I, row) in enumerate (emptyrows)
154
+ @view! (U[Block (row, I)])
155
155
end
156
- # also handle extra rows/cols
157
- for i in (length (emptyrows) + 1 ): length (emptycols)
158
- @view! (Vt[Block (emptycols[i], emptycols[i])])
156
+ for (I, col) in enumerate (emptycols)
157
+ @view! (Vt[Block (I, col)])
159
158
end
160
- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
161
- @view! (U[Block (emptyrows[k], bn + i)])
162
- end
163
-
159
+
164
160
return U, S, Vt
165
161
end
166
162
@@ -185,8 +181,7 @@ function MatrixAlgebraKit.check_input(
185
181
isa (Vᴴ, AbstractBlockSparseMatrix)
186
182
@assert eltype (A) == eltype (U) == eltype (Vᴴ)
187
183
@assert real (eltype (A)) == eltype (S)
188
- @assert axes (A, 1 ) == axes (U, 1 ) && axes (A, 2 ) == axes (Vᴴ, 1 ) == axes (Vᴴ, 2 )
189
- @assert axes (S, 2 ) == axes (A, 2 )
184
+ @assert axes (A, 1 ) == axes (U, 1 ) && axes (A, 2 ) == axes (Vᴴ, 2 )
190
185
return nothing
191
186
end
192
187
@@ -208,7 +203,6 @@ function MatrixAlgebraKit.svd_compact!(
208
203
end
209
204
210
205
# fill in identities for blocks that aren't present
211
- bIs = collect (eachblockstoredindex (A))
212
206
browIs = Int .(first .(Tuple .(bIs)))
213
207
bcolIs = Int .(last .(Tuple .(bIs)))
214
208
emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
@@ -230,36 +224,30 @@ function MatrixAlgebraKit.svd_full!(
230
224
check_input (svd_full!, A, (U, S, Vᴴ))
231
225
232
226
# do decomposition on each block
233
- for bI in eachblockstoredindex (A)
227
+ bIs = sort! (collect (eachblockstoredindex (A)); by= Int ∘ last ∘ Tuple)
228
+ for (I, bI) in enumerate (bIs)
234
229
brow, bcol = Tuple (bI)
235
- usvᴴ = (@view! (U[brow, bcol]), @view! (S[bcol, bcol]), @view! (Vᴴ[bcol, bcol]))
230
+ bcol′ = Block (I)
231
+ usvᴴ = (@view! (U[brow, bcol′]), @view! (S[bcol′, bcol′]), @view! (Vᴴ[bcol′, bcol]))
236
232
block = @view! (A[bI])
237
233
block_alg = block_algorithm (alg, block)
238
234
usvᴴ′ = svd_full! (block, usvᴴ, block_alg)
239
235
@assert usvᴴ === usvᴴ′ " svd_full! might not be in-place"
240
236
end
241
237
242
238
# fill in identities for blocks that aren't present
243
- bIs = collect (eachblockstoredindex (A))
244
239
browIs = Int .(first .(Tuple .(bIs)))
245
240
bcolIs = Int .(last .(Tuple .(bIs)))
246
241
emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
247
242
emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
248
243
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
249
244
# U[Block(row, col)] = LinearAlgebra.I
250
245
# Vt[Block(col, col)] = LinearAlgebra.I
251
- for (row, col) in zip (emptyrows, emptycols)
252
- copyto! (@view! (U[Block (row, col)]), LinearAlgebra. I)
253
- copyto! (@view! (Vᴴ[Block (col, col)]), LinearAlgebra. I)
254
- end
255
-
256
- # also handle extra rows/cols
257
- for i in (length (emptyrows) + 1 ): length (emptycols)
258
- copyto! (@view! (Vᴴ[Block (emptycols[i], emptycols[i])]), LinearAlgebra. I)
246
+ for (I, row) in enumerate (emptyrows)
247
+ copyto! (@view! (U[Block (row, length (bIs) + I)]), LinearAlgebra. I)
259
248
end
260
- bn = blocksize (A, 2 )
261
- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
262
- copyto! (@view! (U[Block (emptyrows[k], bn + i)]), LinearAlgebra. I)
249
+ for (I, col) in enumerate (emptycols)
250
+ copyto! (@view! (Vᴴ[Block (length (bIs) + I, col)]), LinearAlgebra. I)
263
251
end
264
252
265
253
return (U, S, Vᴴ)
0 commit comments