@@ -4,7 +4,6 @@ using LinearAlgebra: LinearAlgebra, Diagonal
4
4
using MatrixAlgebraKit:
5
5
MatrixAlgebraKit,
6
6
TruncationStrategy,
7
- check_input,
8
7
default_eig_algorithm,
9
8
default_eigh_algorithm,
10
9
diagview,
@@ -26,60 +25,100 @@ for f in [:default_eig_algorithm, :default_eigh_algorithm]
26
25
end
27
26
end
28
27
28
+ function output_type (:: typeof (eig_full!), A:: Type{<:AbstractMatrix{T}} ) where {T}
29
+ DV = Base. promote_op (eig_full!, A)
30
+ return if isconcretetype (DV)
31
+ DV
32
+ else
33
+ Tuple{AbstractMatrix{complex (T)},AbstractMatrix{complex (T)}}
34
+ end
35
+ end
36
+ function output_type (:: typeof (eigh_full!), A:: Type{<:AbstractMatrix{T}} ) where {T}
37
+ DV = Base. promote_op (eigh_full!, A)
38
+ return isconcretetype (DV) ? DV : Tuple{AbstractMatrix{real (T)},AbstractMatrix{T}}
39
+ end
40
+
41
+ function MatrixAlgebraKit. initialize_output (
42
+ :: Union{typeof(eig_full!),typeof(eigh_full!)} ,
43
+ :: AbstractBlockSparseMatrix ,
44
+ :: BlockPermutedDiagonalAlgorithm ,
45
+ )
46
+ return nothing
47
+ end
48
+
49
+ function MatrixAlgebraKit. check_input (
50
+ :: Union{typeof(eig_full!),typeof(eigh_full!)} ,
51
+ A:: AbstractBlockSparseMatrix ,
52
+ DV,
53
+ :: BlockPermutedDiagonalAlgorithm ,
54
+ )
55
+ @assert isblockpermuteddiagonal (A)
56
+ return nothing
57
+ end
29
58
function MatrixAlgebraKit. check_input (
30
- :: typeof (eig_full!), A:: AbstractBlockSparseMatrix , (D, V)
59
+ :: typeof (eig_full!), A:: AbstractBlockSparseMatrix , (D, V), :: BlockDiagonalAlgorithm
31
60
)
32
61
@assert isa (D, AbstractBlockSparseMatrix) && isa (V, AbstractBlockSparseMatrix)
33
62
@assert eltype (V) === eltype (D) === complex (eltype (A))
34
63
@assert axes (A, 1 ) == axes (A, 2 )
35
64
@assert axes (A) == axes (D) == axes (V)
65
+ @assert isblockdiagonal (A)
36
66
return nothing
37
67
end
38
68
function MatrixAlgebraKit. check_input (
39
- :: typeof (eigh_full!), A:: AbstractBlockSparseMatrix , (D, V)
69
+ :: typeof (eigh_full!), A:: AbstractBlockSparseMatrix , (D, V), :: BlockDiagonalAlgorithm
40
70
)
41
71
@assert isa (D, AbstractBlockSparseMatrix) && isa (V, AbstractBlockSparseMatrix)
42
72
@assert eltype (V) === eltype (A)
43
73
@assert eltype (D) === real (eltype (A))
44
74
@assert axes (A, 1 ) == axes (A, 2 )
45
75
@assert axes (A) == axes (D) == axes (V)
76
+ @assert isblockdiagonal (A)
46
77
return nothing
47
78
end
48
79
49
- function output_type (f:: typeof (eig_full!), A:: Type{<:AbstractMatrix{T}} ) where {T}
50
- DV = Base. promote_op (f, A)
51
- ! isconcretetype (DV) && return Tuple{AbstractMatrix{complex (T)},AbstractMatrix{complex (T)}}
52
- return DV
53
- end
54
- function output_type (f:: typeof (eigh_full!), A:: Type{<:AbstractMatrix{T}} ) where {T}
55
- DV = Base. promote_op (f, A)
56
- ! isconcretetype (DV) && return Tuple{AbstractMatrix{real (T)},AbstractMatrix{T}}
57
- return DV
58
- end
59
-
60
80
for f in [:eig_full! , :eigh_full! ]
61
81
@eval begin
62
82
function MatrixAlgebraKit. initialize_output (
63
83
:: typeof ($ f), A:: AbstractBlockSparseMatrix , alg:: BlockPermutedDiagonalAlgorithm
84
+ )
85
+ return nothing
86
+ end
87
+ function MatrixAlgebraKit. initialize_output (
88
+ :: typeof ($ f), A:: AbstractBlockSparseMatrix , alg:: BlockDiagonalAlgorithm
64
89
)
65
90
Td, Tv = fieldtypes (output_type ($ f, blocktype (A)))
66
91
D = similar (A, BlockType (Td))
67
92
V = similar (A, BlockType (Tv))
68
93
return (D, V)
69
94
end
70
95
function MatrixAlgebraKit. $f (
71
- A:: AbstractBlockSparseMatrix , (D, V) , alg:: BlockPermutedDiagonalAlgorithm
96
+ A:: AbstractBlockSparseMatrix , DV , alg:: BlockPermutedDiagonalAlgorithm
72
97
)
73
- check_input ($ f, A, (D, V))
74
- for I in eachstoredblockdiagindex (A)
75
- block = @view! (A[I])
76
- block_alg = block_algorithm (alg, block)
77
- D[I], V[I] = $ f (block, block_alg)
78
- end
79
- for I in eachunstoredblockdiagindex (A)
80
- # TODO : Support setting `LinearAlgebra.I` directly, and/or
81
- # using `FillArrays.Eye`.
82
- V[I] = LinearAlgebra. I (size (@view (V[I]), 1 ))
98
+ MatrixAlgebraKit. check_input ($ f, A, DV, alg)
99
+ Ad, transform_rows, transform_cols = blockdiagonalize (A)
100
+ Dd, Vd = $ f (Ad, BlockDiagonalAlgorithm (alg))
101
+ D = transform_rows (Dd)
102
+ V = transform_cols (Vd)
103
+ return D, V
104
+ end
105
+ function MatrixAlgebraKit. $f (
106
+ A:: AbstractBlockSparseMatrix , (D, V), alg:: BlockDiagonalAlgorithm
107
+ )
108
+ MatrixAlgebraKit. check_input ($ f, A, (D, V), alg)
109
+
110
+ # do decomposition on each block
111
+ for I in 1 : min (blocksize (A)... )
112
+ bI = Block (I, I)
113
+ if isstored (blocks (A), CartesianIndex (I, I)) # TODO : isblockstored
114
+ block = @view! (A[bI])
115
+ block_alg = block_algorithm (alg, block)
116
+ bD, bV = $ f (block, block_alg)
117
+ D[bI] = bD
118
+ V[bI] = bV
119
+ else
120
+ copyto! (@view! (V[bI]), LinearAlgebra. I)
121
+ end
83
122
end
84
123
return (D, V)
85
124
end
0 commit comments