@@ -11,21 +11,11 @@ function MatrixAlgebraKit.default_svd_algorithm(
11
11
end
12
12
end
13
13
14
- function output_type (:: typeof (svd_compact!), A:: Type{<:AbstractMatrix{T}} ) where {T}
15
- USVᴴ = Base. promote_op (svd_compact!, A)
16
- ! isconcretetype (USVᴴ) &&
17
- return Tuple{AbstractMatrix{T},AbstractMatrix{realtype (T)},AbstractMatrix{T}}
18
- return USVᴴ
19
- end
20
-
21
- function similar_output (
22
- :: typeof (svd_compact!), A, S_axes, alg:: MatrixAlgebraKit.AbstractAlgorithm
23
- )
24
- BU, BS, BVᴴ = fieldtypes (output_type (svd_compact!, blocktype (A)))
25
- U = similar (A, BlockType (BU), (axes (A, 1 ), S_axes[1 ]))
26
- S = similar (A, BlockType (BS), S_axes)
27
- Vᴴ = similar (A, BlockType (BVᴴ), (S_axes[2 ], axes (A, 2 )))
28
- return U, S, Vᴴ
14
+ function output_type (
15
+ f:: Union{typeof(svd_compact!),typeof(svd_full!)} , A:: Type{<:AbstractMatrix{T}}
16
+ ) where {T}
17
+ USVᴴ = Base. promote_op (f, A)
18
+ return isconcretetype (USVᴴ) ? USVᴴ : Tuple{AbstractMatrix{T},AbstractMatrix{realtype (T)},AbstractMatrix{T}}
29
19
end
30
20
31
21
function MatrixAlgebraKit. initialize_output (
@@ -42,28 +32,13 @@ function MatrixAlgebraKit.initialize_output(
42
32
s_axes = map (splat (infimum), zip (brows, bcols))
43
33
s_axis = mortar_axis (s_axes)
44
34
S_axes = (s_axis, s_axis)
45
- U, S, Vᴴ = similar_output (svd_compact!, A, S_axes, alg)
46
-
47
- for bI in eachblockstoredindex (A)
48
- block = @view! (A[bI])
49
- block_alg = block_algorithm (alg, block)
50
- I = first (Tuple (bI)) # == last(Tuple(bI))
51
- U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit. initialize_output (
52
- svd_compact!, block, block_alg
53
- )
54
- end
55
35
56
- return U, S, Vᴴ
57
- end
36
+ BU, BS, BVᴴ = fieldtypes (output_type (svd_compact!, blocktype (A)))
37
+ U = similar (A, BlockType (BU), (axes (A, 1 ), S_axes[1 ]))
38
+ S = similar (A, BlockType (BS), S_axes)
39
+ Vᴴ = similar (A, BlockType (BVᴴ), (S_axes[2 ], axes (A, 2 )))
58
40
59
- function similar_output (
60
- :: typeof (svd_full!), A, S_axes, alg:: MatrixAlgebraKit.AbstractAlgorithm
61
- )
62
- U = similar (A, axes (A, 1 ), S_axes[1 ])
63
- T = real (eltype (A))
64
- S = similar (A, T, S_axes)
65
- Vt = similar (A, S_axes[2 ], axes (A, 2 ))
66
- return U, S, Vt
41
+ return U, S, Vᴴ
67
42
end
68
43
69
44
function MatrixAlgebraKit. initialize_output (
75
50
function MatrixAlgebraKit. initialize_output (
76
51
:: typeof (svd_full!), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
77
52
)
78
- U, S, Vᴴ = similar_output (svd_full!, A, axes (A), alg)
79
-
80
- for bI in eachblockstoredindex (A)
81
- block = @view! (A[bI])
82
- block_alg = block_algorithm (alg, block)
83
- I = first (Tuple (bI)) # == last(Tuple(bI))
84
- U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit. initialize_output (
85
- svd_full!, block, block_alg
86
- )
87
- end
53
+ BU, BS, BVᴴ = fieldtypes (output_type (svd_full!, blocktype (A)))
54
+ U = similar (A, BlockType (BU), (axes (A, 1 ), axes (A, 1 )))
55
+ S = similar (A, BlockType (BS), axes (A))
56
+ Vᴴ = similar (A, BlockType (BVᴴ), (axes (A, 2 ), axes (A, 2 )))
88
57
89
58
return U, S, Vᴴ
90
59
end
@@ -154,11 +123,12 @@ function MatrixAlgebraKit.svd_compact!(
154
123
for I in 1 : min (blocksize (A)... )
155
124
bI = Block (I, I)
156
125
if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
157
- usvᴴ = (@view! (U[bI]), @view! (S[bI]), @view! (Vᴴ[bI]))
158
126
block = @view! (A[bI])
159
127
block_alg = block_algorithm (alg, block)
160
- usvᴴ′ = svd_compact! (block, usvᴴ, block_alg)
161
- @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
128
+ bU, bS, bVᴴ = svd_compact! (block, block_alg)
129
+ U[bI] = bU
130
+ S[bI] = bS
131
+ Vᴴ[bI] = bVᴴ
162
132
else
163
133
copyto! (@view! (U[bI]), LinearAlgebra. I)
164
134
copyto! (@view! (Vᴴ[bI]), LinearAlgebra. I)
@@ -189,11 +159,12 @@ function MatrixAlgebraKit.svd_full!(
189
159
for I in 1 : min (blocksize (A)... )
190
160
bI = Block (I, I)
191
161
if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
192
- usvᴴ = (@view! (U[bI]), @view! (S[bI]), @view! (Vᴴ[bI]))
193
162
block = @view! (A[bI])
194
163
block_alg = block_algorithm (alg, block)
195
- usvᴴ′ = svd_full! (block, usvᴴ, block_alg)
196
- @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
164
+ bU, bS, bVᴴ = svd_full! (block, block_alg)
165
+ U[bI] = bU
166
+ S[bI] = bS
167
+ Vᴴ[bI] = bVᴴ
197
168
else
198
169
copyto! (@view! (U[bI]), LinearAlgebra. I)
199
170
copyto! (@view! (Vᴴ[bI]), LinearAlgebra. I)
0 commit comments