@@ -29,59 +29,34 @@ function similar_output(
29
29
end
30
30
31
31
function MatrixAlgebraKit. initialize_output (
32
- :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , alg:: BlockPermutedDiagonalAlgorithm
32
+ :: typeof (svd_compact!), :: AbstractBlockSparseMatrix , :: BlockPermutedDiagonalAlgorithm
33
+ )
34
+ return nothing
35
+ end
36
+ function MatrixAlgebraKit. initialize_output (
37
+ :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
33
38
)
34
- bm, bn = blocksize (A)
35
- bmn = min (bm, bn)
36
-
37
39
brows = eachblockaxis (axes (A, 1 ))
38
40
bcols = eachblockaxis (axes (A, 2 ))
39
- u_axes = similar (brows, bmn)
40
- v_axes = similar (brows, bmn)
41
+ # using the property that zip stops as soon as one of the iterators is exhausted
42
+ s_axes = map (splat (infimum), zip (brows, bcols))
43
+ s_axis = mortar_axis (s_axes)
44
+ S_axes = (s_axis, s_axis)
45
+ U, S, Vᴴ = similar_output (svd_compact!, A, S_axes, alg)
41
46
42
- # fill in values for blocks that are present
43
- bIs = collect (eachblockstoredindex (A))
44
- browIs = Int .(first .(Tuple .(bIs)))
45
- bcolIs = Int .(last .(Tuple .(bIs)))
46
47
for bI in eachblockstoredindex (A)
47
- row, col = Int .(Tuple (bI))
48
- u_axes[col] = infimum (brows[row], bcols[col])
49
- v_axes[col] = infimum (bcols[col], brows[row])
50
- end
51
-
52
- # fill in values for blocks that aren't present, pairing them in order of occurence
53
- # this is a convention, which at least gives the expected results for blockdiagonal
54
- emptyrows = setdiff (1 : bm, browIs)
55
- emptycols = setdiff (1 : bn, bcolIs)
56
- for (row, col) in zip (emptyrows, emptycols)
57
- u_axes[col] = infimum (brows[row], bcols[col])
58
- v_axes[col] = infimum (bcols[col], brows[row])
59
- end
60
-
61
- u_axis = mortar_axis (u_axes)
62
- v_axis = mortar_axis (v_axes)
63
- S_axes = (u_axis, v_axis)
64
- U, S, Vt = similar_output (svd_compact!, A, S_axes, alg)
65
-
66
- # allocate output
67
- for bI in eachblockstoredindex (A)
68
- brow, bcol = Tuple (bI)
69
48
block = @view! (A[bI])
70
49
block_alg = block_algorithm (alg, block)
71
- U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit. initialize_output (
50
+ I = first (Tuple (bI)) # == last(Tuple(bI))
51
+ U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit. initialize_output (
72
52
svd_compact!, block, block_alg
73
53
)
74
54
end
75
55
76
- # allocate output for blocks that aren't present -- do we also fill identities here?
77
- for (row, col) in zip (emptyrows, emptycols)
78
- @view! (U[Block (row, col)])
79
- @view! (Vt[Block (col, col)])
80
- end
81
-
82
- return U, S, Vt
56
+ return U, S, Vᴴ
83
57
end
84
58
59
+
85
60
function similar_output (
86
61
:: typeof (svd_full!), A, S_axes, alg:: MatrixAlgebraKit.AbstractAlgorithm
87
62
)
@@ -93,65 +68,39 @@ function similar_output(
93
68
end
94
69
95
70
function MatrixAlgebraKit. initialize_output (
96
- :: typeof (svd_full!), A :: AbstractBlockSparseMatrix , alg :: BlockPermutedDiagonalAlgorithm
71
+ :: typeof (svd_full!), :: AbstractBlockSparseMatrix , :: BlockPermutedDiagonalAlgorithm
97
72
)
98
- bm, bn = blocksize (A)
99
-
100
- brows = eachblockaxis (axes (A, 1 ))
101
- u_axes = similar (brows)
102
-
103
- # fill in values for blocks that are present
104
- bIs = collect (eachblockstoredindex (A))
105
- browIs = Int .(first .(Tuple .(bIs)))
106
- bcolIs = Int .(last .(Tuple .(bIs)))
107
- for bI in eachblockstoredindex (A)
108
- row, col = Int .(Tuple (bI))
109
- u_axes[col] = brows[row]
110
- end
111
-
112
- # fill in values for blocks that aren't present, pairing them in order of occurence
113
- # this is a convention, which at least gives the expected results for blockdiagonal
114
- emptyrows = setdiff (1 : bm, browIs)
115
- emptycols = setdiff (1 : bn, bcolIs)
116
- for (row, col) in zip (emptyrows, emptycols)
117
- u_axes[col] = brows[row]
118
- end
119
- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
120
- u_axes[bn + i] = brows[emptyrows[k]]
121
- end
73
+ return nothing
74
+ end
122
75
123
- u_axis = mortar_axis (u_axes)
124
- S_axes = (u_axis, axes (A, 2 ))
125
- U, S, Vt = similar_output (svd_full!, A, S_axes, alg)
76
+ function MatrixAlgebraKit. initialize_output (
77
+ :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
78
+ )
79
+ U, S, Vᴴ = similar_output (svd_full!, A, axes (A), alg)
126
80
127
- # allocate output
128
81
for bI in eachblockstoredindex (A)
129
- brow, bcol = Tuple (bI)
130
82
block = @view! (A[bI])
131
83
block_alg = block_algorithm (alg, block)
132
- U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit. initialize_output (
84
+ I = first (Tuple (bI)) # == last(Tuple(bI))
85
+ U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit. initialize_output (
133
86
svd_full!, block, block_alg
134
87
)
135
88
end
136
89
137
- # allocate output for blocks that aren't present -- do we also fill identities here?
138
- for (row, col) in zip (emptyrows, emptycols)
139
- @view! (U[Block (row, col)])
140
- @view! (Vt[Block (col, col)])
141
- end
142
- # also handle extra rows/cols
143
- for i in (length (emptyrows) + 1 ): length (emptycols)
144
- @view! (Vt[Block (emptycols[i], emptycols[i])])
145
- end
146
- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
147
- @view! (U[Block (emptyrows[k], bn + i)])
148
- end
90
+ return U, S, Vᴴ
91
+ end
149
92
150
- return U, S, Vt
93
+ function MatrixAlgebraKit. check_input (
94
+ :: typeof (svd_compact!),
95
+ A:: AbstractBlockSparseMatrix ,
96
+ USVᴴ,
97
+ :: BlockPermutedDiagonalAlgorithm ,
98
+ )
99
+ @assert isblockpermuteddiagonal (A)
151
100
end
152
101
153
102
function MatrixAlgebraKit. check_input (
154
- :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ)
103
+ :: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), :: BlockDiagonalAlgorithm
155
104
)
156
105
@assert isa (U, AbstractBlockSparseMatrix) &&
157
106
isa (S, AbstractBlockSparseMatrix) &&
@@ -160,11 +109,19 @@ function MatrixAlgebraKit.check_input(
160
109
@assert real (eltype (A)) == eltype (S)
161
110
@assert axes (A, 1 ) == axes (U, 1 ) && axes (A, 2 ) == axes (Vᴴ, 2 )
162
111
@assert axes (S, 1 ) == axes (S, 2 )
112
+ @assert isblockdiagonal (A)
163
113
return nothing
164
114
end
165
115
166
116
function MatrixAlgebraKit. check_input (
167
- :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ)
117
+ :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , USVᴴ, :: BlockPermutedDiagonalAlgorithm
118
+ )
119
+ @assert isblockpermuteddiagonal (A)
120
+ return nothing
121
+ end
122
+
123
+ function MatrixAlgebraKit. check_input (
124
+ :: typeof (svd_full!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), :: BlockDiagonalAlgorithm
168
125
)
169
126
@assert isa (U, AbstractBlockSparseMatrix) &&
170
127
isa (S, AbstractBlockSparseMatrix) &&
@@ -173,78 +130,92 @@ function MatrixAlgebraKit.check_input(
173
130
@assert real (eltype (A)) == eltype (S)
174
131
@assert axes (A, 1 ) == axes (U, 1 ) && axes (A, 2 ) == axes (Vᴴ, 1 ) == axes (Vᴴ, 2 )
175
132
@assert axes (S, 2 ) == axes (A, 2 )
133
+ @assert isblockdiagonal (A)
176
134
return nothing
177
135
end
178
136
179
137
function MatrixAlgebraKit. svd_compact! (
180
- A:: AbstractBlockSparseMatrix , (U, S, Vᴴ) , alg:: BlockPermutedDiagonalAlgorithm
138
+ A:: AbstractBlockSparseMatrix , USVᴴ , alg:: BlockPermutedDiagonalAlgorithm
181
139
)
182
- check_input (svd_compact!, A, (U, S, Vᴴ) )
140
+ check_input (svd_compact!, A, USVᴴ, alg )
183
141
184
- # do decomposition on each block
185
- for bI in eachblockstoredindex (A)
186
- brow, bcol = Tuple (bI)
187
- usvᴴ = (@view! (U[brow, bcol]), @view! (S[bcol, bcol]), @view! (Vᴴ[bcol, bcol]))
188
- block = @view! (A[bI])
189
- block_alg = block_algorithm (alg, block)
190
- usvᴴ′ = svd_compact! (block, usvᴴ, block_alg)
191
- @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
192
- end
142
+ Ad, rowperm, colperm = blockdiagonalize (A)
143
+ Ud, S, Vᴴd = svd_compact! (Ad, BlockDiagonalAlgorithm (alg))
144
+
145
+ inv_rowperm = Block .(invperm (Int .(rowperm)))
146
+ U = Ud[inv_rowperm, :]
147
+
148
+ inv_colperm = Block .(invperm (Int .(colperm)))
149
+ Vᴴ = Vᴴd[:, inv_colperm]
150
+
151
+ return U, S, Vᴴ
152
+ end
193
153
194
- # fill in identities for blocks that aren't present
195
- bIs = collect (eachblockstoredindex (A))
196
- browIs = Int .(first .(Tuple .(bIs)))
197
- bcolIs = Int .(last .(Tuple .(bIs)))
198
- emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
199
- emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
200
- # needs copyto! instead because size(::LinearAlgebra.I) doesn't work
201
- # U[Block(row, col)] = LinearAlgebra.I
202
- # Vᴴ[Block(col, col)] = LinearAlgebra.I
203
- for (row, col) in zip (emptyrows, emptycols)
204
- copyto! (@view! (U[Block (row, col)]), LinearAlgebra. I)
205
- copyto! (@view! (Vᴴ[Block (col, col)]), LinearAlgebra. I)
154
+ function MatrixAlgebraKit. svd_compact! (
155
+ A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), alg:: BlockDiagonalAlgorithm
156
+ )
157
+ check_input (svd_compact!, A, (U, S, Vᴴ), alg)
158
+
159
+ for I in 1 : min (blocksize (A)... )
160
+ bI = Block (I, I)
161
+ if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
162
+ usvᴴ = (@view! (U[bI]), @view! (S[bI]), @view! (Vᴴ[bI]))
163
+ block = @view! (A[bI])
164
+ block_alg = block_algorithm (alg, block)
165
+ usvᴴ′ = svd_compact! (block, usvᴴ, block_alg)
166
+ @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
167
+ else
168
+ copyto! (@view! (U[bI]), LinearAlgebra. I)
169
+ copyto! (@view! (Vᴴ[bI]), LinearAlgebra. I)
170
+ end
206
171
end
207
172
208
- return ( U, S, Vᴴ)
173
+ return U, S, Vᴴ
209
174
end
210
175
211
176
function MatrixAlgebraKit. svd_full! (
212
- A:: AbstractBlockSparseMatrix , (U, S, Vᴴ) , alg:: BlockPermutedDiagonalAlgorithm
177
+ A:: AbstractBlockSparseMatrix , USVᴴ , alg:: BlockPermutedDiagonalAlgorithm
213
178
)
214
- check_input (svd_full!, A, (U, S, Vᴴ) )
179
+ check_input (svd_full!, A, USVᴴ, alg )
215
180
216
- # do decomposition on each block
217
- for bI in eachblockstoredindex (A)
218
- brow, bcol = Tuple (bI)
219
- usvᴴ = (@view! (U[brow, bcol]), @view! (S[bcol, bcol]), @view! (Vᴴ[bcol, bcol]))
220
- block = @view! (A[bI])
221
- block_alg = block_algorithm (alg, block)
222
- usvᴴ′ = svd_full! (block, usvᴴ, block_alg)
223
- @assert usvᴴ === usvᴴ′ " svd_full! might not be in-place"
224
- end
181
+ Ad, rowperm, colperm = blockdiagonalize (A)
182
+ Ud, S, Vᴴd = svd_full! (Ad, BlockDiagonalAlgorithm (alg))
183
+
184
+ inv_rowperm = Block .(invperm (Int .(rowperm)))
185
+ U = Ud[inv_rowperm, :]
225
186
226
- # fill in identities for blocks that aren't present
227
- bIs = collect (eachblockstoredindex (A))
228
- browIs = Int .(first .(Tuple .(bIs)))
229
- bcolIs = Int .(last .(Tuple .(bIs)))
230
- emptyrows = setdiff (1 : blocksize (A, 1 ), browIs)
231
- emptycols = setdiff (1 : blocksize (A, 2 ), bcolIs)
232
- # needs copyto! instead because size(::LinearAlgebra.I) doesn't work
233
- # U[Block(row, col)] = LinearAlgebra.I
234
- # Vt[Block(col, col)] = LinearAlgebra.I
235
- for (row, col) in zip (emptyrows, emptycols)
236
- copyto! (@view! (U[Block (row, col)]), LinearAlgebra. I)
237
- copyto! (@view! (Vᴴ[Block (col, col)]), LinearAlgebra. I)
187
+ inv_colperm = Block .(invperm (Int .(colperm)))
188
+ Vᴴ = Vᴴd[:, inv_colperm]
189
+
190
+ return U, S, Vᴴ
191
+ end
192
+
193
+ function MatrixAlgebraKit. svd_full! (
194
+ A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), alg:: BlockDiagonalAlgorithm
195
+ )
196
+ check_input (svd_full!, A, (U, S, Vᴴ), alg)
197
+
198
+ for I in 1 : min (blocksize (A)... )
199
+ bI = Block (I, I)
200
+ if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
201
+ usvᴴ = (@view! (U[bI]), @view! (S[bI]), @view! (Vᴴ[bI]))
202
+ block = @view! (A[bI])
203
+ block_alg = block_algorithm (alg, block)
204
+ usvᴴ′ = svd_full! (block, usvᴴ, block_alg)
205
+ @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
206
+ else
207
+ copyto! (@view! (U[bI]), LinearAlgebra. I)
208
+ copyto! (@view! (Vᴴ[bI]), LinearAlgebra. I)
209
+ end
238
210
end
239
211
240
- # also handle extra rows/cols
241
- for i in ( length (emptyrows) + 1 ) : length (emptycols )
242
- copyto! (@view! (Vᴴ [Block (emptycols[i], emptycols[i] )]), LinearAlgebra. I)
212
+ # Complete the unitaries for rectangular inputs
213
+ for I in blocksize (A, 2 ) + 1 : blocksize (A, 1 )
214
+ copyto! (@view! (U [Block (I, I )]), LinearAlgebra. I)
243
215
end
244
- bn = blocksize (A, 2 )
245
- for (i, k) in enumerate ((length (emptycols) + 1 ): length (emptyrows))
246
- copyto! (@view! (U[Block (emptyrows[k], bn + i)]), LinearAlgebra. I)
216
+ for I in blocksize (A, 1 )+ 1 : blocksize (A, 2 )
217
+ copyto! (@view! (Vᴴ[Block (I, I)]), LinearAlgebra. I)
247
218
end
248
219
249
- return ( U, S, Vᴴ)
220
+ return U, S, Vᴴ
250
221
end
0 commit comments