Skip to content

Commit 710939b

Browse files
committed
Refactor eig
1 parent b7286b1 commit 710939b

File tree

1 file changed

+64
-25
lines changed

1 file changed

+64
-25
lines changed

src/factorizations/eig.jl

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ using LinearAlgebra: LinearAlgebra, Diagonal
44
using MatrixAlgebraKit:
55
MatrixAlgebraKit,
66
TruncationStrategy,
7-
check_input,
87
default_eig_algorithm,
98
default_eigh_algorithm,
109
diagview,
@@ -26,60 +25,100 @@ for f in [:default_eig_algorithm, :default_eigh_algorithm]
2625
end
2726
end
2827

28+
function output_type(::typeof(eig_full!), A::Type{<:AbstractMatrix{T}}) where {T}
29+
DV = Base.promote_op(eig_full!, A)
30+
return if isconcretetype(DV)
31+
DV
32+
else
33+
Tuple{AbstractMatrix{complex(T)},AbstractMatrix{complex(T)}}
34+
end
35+
end
36+
function output_type(::typeof(eigh_full!), A::Type{<:AbstractMatrix{T}}) where {T}
37+
DV = Base.promote_op(eigh_full!, A)
38+
return isconcretetype(DV) ? DV : Tuple{AbstractMatrix{real(T)},AbstractMatrix{T}}
39+
end
40+
41+
function MatrixAlgebraKit.initialize_output(
42+
::Union{typeof(eig_full!),typeof(eigh_full!)},
43+
::AbstractBlockSparseMatrix,
44+
::BlockPermutedDiagonalAlgorithm,
45+
)
46+
return nothing
47+
end
48+
49+
function MatrixAlgebraKit.check_input(
50+
::Union{typeof(eig_full!),typeof(eigh_full!)},
51+
A::AbstractBlockSparseMatrix,
52+
DV,
53+
::BlockPermutedDiagonalAlgorithm,
54+
)
55+
@assert isblockpermuteddiagonal(A)
56+
return nothing
57+
end
2958
function MatrixAlgebraKit.check_input(
30-
::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V)
59+
::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V), ::BlockDiagonalAlgorithm
3160
)
3261
@assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix)
3362
@assert eltype(V) === eltype(D) === complex(eltype(A))
3463
@assert axes(A, 1) == axes(A, 2)
3564
@assert axes(A) == axes(D) == axes(V)
65+
@assert isblockdiagonal(A)
3666
return nothing
3767
end
3868
function MatrixAlgebraKit.check_input(
39-
::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V)
69+
::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V), ::BlockDiagonalAlgorithm
4070
)
4171
@assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix)
4272
@assert eltype(V) === eltype(A)
4373
@assert eltype(D) === real(eltype(A))
4474
@assert axes(A, 1) == axes(A, 2)
4575
@assert axes(A) == axes(D) == axes(V)
76+
@assert isblockdiagonal(A)
4677
return nothing
4778
end
4879

49-
function output_type(f::typeof(eig_full!), A::Type{<:AbstractMatrix{T}}) where {T}
50-
DV = Base.promote_op(f, A)
51-
!isconcretetype(DV) && return Tuple{AbstractMatrix{complex(T)},AbstractMatrix{complex(T)}}
52-
return DV
53-
end
54-
function output_type(f::typeof(eigh_full!), A::Type{<:AbstractMatrix{T}}) where {T}
55-
DV = Base.promote_op(f, A)
56-
!isconcretetype(DV) && return Tuple{AbstractMatrix{real(T)},AbstractMatrix{T}}
57-
return DV
58-
end
59-
6080
for f in [:eig_full!, :eigh_full!]
6181
@eval begin
6282
function MatrixAlgebraKit.initialize_output(
6383
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
84+
)
85+
return nothing
86+
end
87+
function MatrixAlgebraKit.initialize_output(
88+
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
6489
)
6590
Td, Tv = fieldtypes(output_type($f, blocktype(A)))
6691
D = similar(A, BlockType(Td))
6792
V = similar(A, BlockType(Tv))
6893
return (D, V)
6994
end
7095
function MatrixAlgebraKit.$f(
71-
A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm
96+
A::AbstractBlockSparseMatrix, DV, alg::BlockPermutedDiagonalAlgorithm
7297
)
73-
check_input($f, A, (D, V))
74-
for I in eachstoredblockdiagindex(A)
75-
block = @view!(A[I])
76-
block_alg = block_algorithm(alg, block)
77-
D[I], V[I] = $f(block, block_alg)
78-
end
79-
for I in eachunstoredblockdiagindex(A)
80-
# TODO: Support setting `LinearAlgebra.I` directly, and/or
81-
# using `FillArrays.Eye`.
82-
V[I] = LinearAlgebra.I(size(@view(V[I]), 1))
98+
MatrixAlgebraKit.check_input($f, A, DV, alg)
99+
Ad, transform_rows, transform_cols = blockdiagonalize(A)
100+
Dd, Vd = $f(Ad, BlockDiagonalAlgorithm(alg))
101+
D = transform_rows(Dd)
102+
V = transform_cols(Vd)
103+
return D, V
104+
end
105+
function MatrixAlgebraKit.$f(
106+
A::AbstractBlockSparseMatrix, (D, V), alg::BlockDiagonalAlgorithm
107+
)
108+
MatrixAlgebraKit.check_input($f, A, (D, V), alg)
109+
110+
# do decomposition on each block
111+
for I in 1:min(blocksize(A)...)
112+
bI = Block(I, I)
113+
if isstored(blocks(A), CartesianIndex(I, I)) # TODO: isblockstored
114+
block = @view!(A[bI])
115+
block_alg = block_algorithm(alg, block)
116+
bD, bV = $f(block, block_alg)
117+
D[bI] = bD
118+
V[bI] = bV
119+
else
120+
copyto!(@view!(V[bI]), LinearAlgebra.I)
121+
end
83122
end
84123
return (D, V)
85124
end

0 commit comments

Comments
 (0)