Skip to content

Commit 2fa9476

Browse files
committed
simplify initializations
1 parent 4c632e8 commit 2fa9476

File tree

1 file changed

+22
-51
lines changed

1 file changed

+22
-51
lines changed

src/factorizations/svd.jl

Lines changed: 22 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,11 @@ function MatrixAlgebraKit.default_svd_algorithm(
1111
end
1212
end
1313

14-
function output_type(::typeof(svd_compact!), A::Type{<:AbstractMatrix{T}}) where {T}
15-
USVᴴ = Base.promote_op(svd_compact!, A)
16-
!isconcretetype(USVᴴ) &&
17-
return Tuple{AbstractMatrix{T},AbstractMatrix{realtype(T)},AbstractMatrix{T}}
18-
return USVᴴ
19-
end
20-
21-
function similar_output(
22-
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
23-
)
24-
BU, BS, BVᴴ = fieldtypes(output_type(svd_compact!, blocktype(A)))
25-
U = similar(A, BlockType(BU), (axes(A, 1), S_axes[1]))
26-
S = similar(A, BlockType(BS), S_axes)
27-
Vᴴ = similar(A, BlockType(BVᴴ), (S_axes[2], axes(A, 2)))
28-
return U, S, Vᴴ
14+
function output_type(
15+
f::Union{typeof(svd_compact!),typeof(svd_full!)}, A::Type{<:AbstractMatrix{T}}
16+
) where {T}
17+
USVᴴ = Base.promote_op(f, A)
18+
return isconcretetype(USVᴴ) ? USVᴴ : Tuple{AbstractMatrix{T},AbstractMatrix{realtype(T)},AbstractMatrix{T}}
2919
end
3020

3121
function MatrixAlgebraKit.initialize_output(
@@ -42,28 +32,13 @@ function MatrixAlgebraKit.initialize_output(
4232
s_axes = map(splat(infimum), zip(brows, bcols))
4333
s_axis = mortar_axis(s_axes)
4434
S_axes = (s_axis, s_axis)
45-
U, S, Vᴴ = similar_output(svd_compact!, A, S_axes, alg)
46-
47-
for bI in eachblockstoredindex(A)
48-
block = @view!(A[bI])
49-
block_alg = block_algorithm(alg, block)
50-
I = first(Tuple(bI)) # == last(Tuple(bI))
51-
U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit.initialize_output(
52-
svd_compact!, block, block_alg
53-
)
54-
end
5535

56-
return U, S, Vᴴ
57-
end
36+
BU, BS, BVᴴ = fieldtypes(output_type(svd_compact!, blocktype(A)))
37+
U = similar(A, BlockType(BU), (axes(A, 1), S_axes[1]))
38+
S = similar(A, BlockType(BS), S_axes)
39+
Vᴴ = similar(A, BlockType(BVᴴ), (S_axes[2], axes(A, 2)))
5840

59-
function similar_output(
60-
::typeof(svd_full!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
61-
)
62-
U = similar(A, axes(A, 1), S_axes[1])
63-
T = real(eltype(A))
64-
S = similar(A, T, S_axes)
65-
Vt = similar(A, S_axes[2], axes(A, 2))
66-
return U, S, Vt
41+
return U, S, Vᴴ
6742
end
6843

6944
function MatrixAlgebraKit.initialize_output(
@@ -75,16 +50,10 @@ end
7550
function MatrixAlgebraKit.initialize_output(
7651
::typeof(svd_full!), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
7752
)
78-
U, S, Vᴴ = similar_output(svd_full!, A, axes(A), alg)
79-
80-
for bI in eachblockstoredindex(A)
81-
block = @view!(A[bI])
82-
block_alg = block_algorithm(alg, block)
83-
I = first(Tuple(bI)) # == last(Tuple(bI))
84-
U[I, I], S[I, I], Vᴴ[I, I] = MatrixAlgebraKit.initialize_output(
85-
svd_full!, block, block_alg
86-
)
87-
end
53+
BU, BS, BVᴴ = fieldtypes(output_type(svd_full!, blocktype(A)))
54+
U = similar(A, BlockType(BU), (axes(A, 1), axes(A, 1)))
55+
S = similar(A, BlockType(BS), axes(A))
56+
Vᴴ = similar(A, BlockType(BVᴴ), (axes(A, 2), axes(A, 2)))
8857

8958
return U, S, Vᴴ
9059
end
@@ -154,11 +123,12 @@ function MatrixAlgebraKit.svd_compact!(
154123
for I in 1:min(blocksize(A)...)
155124
bI = Block(I, I)
156125
if isstored(blocks(A), CartesianIndex(I, I)) # TODO: isblockstored
157-
usvᴴ = (@view!(U[bI]), @view!(S[bI]), @view!(Vᴴ[bI]))
158126
block = @view!(A[bI])
159127
block_alg = block_algorithm(alg, block)
160-
usvᴴ′ = svd_compact!(block, usvᴴ, block_alg)
161-
@assert usvᴴ === usvᴴ′ "svd_compact! might not be in-place"
128+
bU, bS, bVᴴ = svd_compact!(block, block_alg)
129+
U[bI] = bU
130+
S[bI] = bS
131+
Vᴴ[bI] = bVᴴ
162132
else
163133
copyto!(@view!(U[bI]), LinearAlgebra.I)
164134
copyto!(@view!(Vᴴ[bI]), LinearAlgebra.I)
@@ -189,11 +159,12 @@ function MatrixAlgebraKit.svd_full!(
189159
for I in 1:min(blocksize(A)...)
190160
bI = Block(I, I)
191161
if isstored(blocks(A), CartesianIndex(I, I)) # TODO: isblockstored
192-
usvᴴ = (@view!(U[bI]), @view!(S[bI]), @view!(Vᴴ[bI]))
193162
block = @view!(A[bI])
194163
block_alg = block_algorithm(alg, block)
195-
usvᴴ′ = svd_full!(block, usvᴴ, block_alg)
196-
@assert usvᴴ === usvᴴ′ "svd_compact! might not be in-place"
164+
bU, bS, bVᴴ = svd_full!(block, block_alg)
165+
U[bI] = bU
166+
S[bI] = bS
167+
Vᴴ[bI] = bVᴴ
197168
else
198169
copyto!(@view!(U[bI]), LinearAlgebra.I)
199170
copyto!(@view!(Vᴴ[bI]), LinearAlgebra.I)

0 commit comments

Comments
 (0)