1
1
using DiagonalArrays: diagonaltype
2
2
using MatrixAlgebraKit:
3
- MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!
3
+ MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!, svd_vals!
4
4
using TypeParameterAccessors: realtype
5
5
6
6
function MatrixAlgebraKit. default_svd_algorithm (
@@ -15,7 +15,15 @@ function output_type(
15
15
f:: Union{typeof(svd_compact!),typeof(svd_full!)} , A:: Type{<:AbstractMatrix{T}}
16
16
) where {T}
17
17
USVᴴ = Base. promote_op (f, A)
18
- return isconcretetype (USVᴴ) ? USVᴴ : Tuple{AbstractMatrix{T},AbstractMatrix{realtype (T)},AbstractMatrix{T}}
18
+ return if isconcretetype (USVᴴ)
19
+ USVᴴ
20
+ else
21
+ Tuple{AbstractMatrix{T},AbstractMatrix{realtype (T)},AbstractMatrix{T}}
22
+ end
23
+ end
24
+ function output_type (:: typeof (svd_vals!), A:: Type{<:AbstractMatrix{T}} ) where {T}
25
+ S = Base. promote_op (svd_vals!, A)
26
+ return isconcretetype (S) ? S : AbstractVector{real (T)}
19
27
end
20
28
21
29
function MatrixAlgebraKit. initialize_output (
@@ -46,7 +54,6 @@ function MatrixAlgebraKit.initialize_output(
46
54
)
47
55
return nothing
48
56
end
49
-
50
57
function MatrixAlgebraKit. initialize_output (
51
58
:: typeof (svd_full!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
52
59
)
@@ -58,6 +65,24 @@ function MatrixAlgebraKit.initialize_output(
58
65
return U, S, Vᴴ
59
66
end
60
67
68
+ function MatrixAlgebraKit. initialize_output (
69
+ :: typeof (svd_vals!), :: AbstractBlockSparseMatrix , :: BlockDiagonalAlgorithm
70
+ )
71
+ return nothing
72
+ end
73
+ function MatrixAlgebraKit. initialize_output (
74
+ :: typeof (svd_vals!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
75
+ )
76
+ brows = eachblockaxis (axes (A, 1 ))
77
+ bcols = eachblockaxis (axes (A, 2 ))
78
+ # using the property that zip stops as soon as one of the iterators is exhausted
79
+ s_axes = map (splat (infimum), zip (brows, bcols))
80
+ s_axis = mortar_axis (s_axes)
81
+
82
+ BS = output_type (svd_vals!, blocktype (A))
83
+ return similar (A, BlockType (BS), S_axes)
84
+ end
85
+
61
86
function MatrixAlgebraKit. check_input (
62
87
:: typeof (svd_compact!),
63
88
A:: AbstractBlockSparseMatrix ,
@@ -66,7 +91,6 @@ function MatrixAlgebraKit.check_input(
66
91
)
67
92
@assert isblockpermuteddiagonal (A)
68
93
end
69
-
70
94
function MatrixAlgebraKit. check_input (
71
95
:: typeof (svd_compact!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), :: BlockDiagonalAlgorithm
72
96
)
@@ -87,7 +111,6 @@ function MatrixAlgebraKit.check_input(
87
111
@assert isblockpermuteddiagonal (A)
88
112
return nothing
89
113
end
90
-
91
114
function MatrixAlgebraKit. check_input (
92
115
:: typeof (svd_full!), A:: AbstractBlockSparseMatrix , (U, S, Vᴴ), :: BlockDiagonalAlgorithm
93
116
)
@@ -102,15 +125,30 @@ function MatrixAlgebraKit.check_input(
102
125
return nothing
103
126
end
104
127
128
+ function MatrixAlgebraKit. check_input (
129
+ :: typeof (svd_vals!), A:: AbstractBlockSparseMatrix , S, :: BlockPermutedDiagonalAlgorithm
130
+ )
131
+ @assert isblockpermuteddiagonal (A)
132
+ return nothing
133
+ end
134
+ function MatrixAlgebraKit. check_input (
135
+ :: typeof (svd_vals!), A:: AbstractBlockSparseMatrix , S, :: BlockDiagonalAlgorithm
136
+ )
137
+ @assert isa (S, AbstractBlockSparseVector)
138
+ @assert real (eltype (A)) == eltype (S)
139
+ @assert isblockdiagonal (A)
140
+ return nothing
141
+ end
142
+
105
143
function MatrixAlgebraKit. svd_compact! (
106
144
A:: AbstractBlockSparseMatrix , USVᴴ, alg:: BlockPermutedDiagonalAlgorithm
107
145
)
108
146
check_input (svd_compact!, A, USVᴴ, alg)
109
147
110
- Ad, transform_rows, transform_cols = blockdiagonalize (A)
148
+ Ad, (invrowperm, invcolperm) = blockdiagonalize (A)
111
149
Ud, S, Vᴴd = svd_compact! (Ad, BlockDiagonalAlgorithm (alg))
112
- U = transform_rows (Ud)
113
- Vᴴ = transform_cols (Vᴴd)
150
+ U = transform_rows (Ud, invrowperm )
151
+ Vᴴ = transform_cols (Vᴴd, invcolperm )
114
152
115
153
return U, S, Vᴴ
116
154
end
@@ -143,10 +181,10 @@ function MatrixAlgebraKit.svd_full!(
143
181
)
144
182
check_input (svd_full!, A, USVᴴ, alg)
145
183
146
- Ad, transform_rows, transform_cols = blockdiagonalize (A)
184
+ Ad, (invrowperm, invcolperm) = blockdiagonalize (A)
147
185
Ud, S, Vᴴd = svd_full! (Ad, BlockDiagonalAlgorithm (alg))
148
- U = transform_rows (Ud)
149
- Vᴴ = transform_cols (Vᴴd)
186
+ U = transform_rows (Ud, invrowperm )
187
+ Vᴴ = transform_cols (Vᴴd, invcolperm )
150
188
151
189
return U, S, Vᴴ
152
190
end
@@ -181,3 +219,21 @@ function MatrixAlgebraKit.svd_full!(
181
219
182
220
return U, S, Vᴴ
183
221
end
222
+
223
+ function MatrixAlgebraKit. svd_vals! (
224
+ A:: AbstractBlockSparseMatrix , S, alg:: BlockPermutedDiagonalAlgorithm
225
+ )
226
+ MatrixAlgebraKit. check_input (svd_vals!, A, S, alg)
227
+ Ad, _ = blockdiagonalize (A)
228
+ return svd_vals! (Ad, BlockDiagonalAlgorithm (alg))
229
+ end
230
+ function MatrixAlgebraKit. svd_vals! (
231
+ A:: AbstractBlockSparseMatrix , S, alg:: BlockDiagonalAlgorithm
232
+ )
233
+ MatrixAlgebraKit. check_input (svd_vals!, A, S, alg)
234
+ for I in eachblockstoredindex (A)
235
+ block = @view! (A[I])
236
+ S[Tuple (I)[1 ]] = $ f (block, block_algorithm (alg, block))
237
+ end
238
+ return S
239
+ end
0 commit comments