Skip to content

Commit 8af8e5b

Browse files
committed
More specific block type, code sharing across Hermitian and non-Hermitian
1 parent 7077979 commit 8af8e5b

File tree

2 files changed

+34
-37
lines changed

2 files changed

+34
-37
lines changed

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,12 @@ function Base.similar(
335335
return @interface BlockSparseArrayInterface() similar(a, elt, axes)
336336
end
337337

338+
struct BlockType{T} end
339+
BlockType(x) = BlockType{x}()
340+
function Base.similar(a::AbstractBlockSparseArray, ::BlockType{T}) where {T}
341+
return BlockSparseArray{eltype(T),ndims(T),T}(undef, axes(a))
342+
end
343+
338344
# TODO: Implement this in a more generic way using a smarter `copyto!`,
339345
# which is ultimately what `Array{T,N}(::AbstractArray{<:Any,N})` calls.
340346
# These are defined for now to avoid scalar indexing issues when there

src/factorizations/eig.jl

Lines changed: 28 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,45 @@
11
using MatrixAlgebraKit:
22
MatrixAlgebraKit, default_eig_algorithm, default_eigh_algorithm, eig_full!, eigh_full!
33

4-
function MatrixAlgebraKit.default_eig_algorithm(
5-
arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...
4+
function initialize_blocksparse_eig_output(
5+
f, A::AbstractMatrix, alg::BlockPermutedDiagonalAlgorithm
66
)
7-
alg = default_eig_algorithm(blocktype(arrayt); kwargs...)
8-
return BlockPermutedDiagonalAlgorithm(alg)
9-
end
10-
11-
function MatrixAlgebraKit.initialize_output(
12-
::typeof(eig_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
13-
)
14-
D = similar(A, complex(eltype(A)))
15-
V = similar(A, complex(eltype(A)))
7+
Td, Tv = fieldtypes(Base.promote_op(f, blocktype(A), typeof(alg.alg)))
8+
D = similar(A, BlockType(Td))
9+
V = similar(A, BlockType(Tv))
1610
return (D, V)
1711
end
1812

19-
function MatrixAlgebraKit.eig_full!(
20-
A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm
13+
function blocksparse_eig_full!(
14+
f, A::AbstractMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm
2115
)
2216
for I in blockdiagindices(A)
23-
d, v = eig_full!(A[I], alg.alg)
24-
D[I] = d
25-
V[I] = v
17+
d, v = f(@view!(A[I]), alg.alg)
18+
D[I], V[I] = d, v
2619
end
2720
return (D, V)
2821
end
2922

30-
function MatrixAlgebraKit.default_eigh_algorithm(
31-
arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...
32-
)
33-
alg = default_eigh_algorithm(blocktype(arrayt); kwargs...)
34-
return BlockPermutedDiagonalAlgorithm(alg)
35-
end
36-
37-
function MatrixAlgebraKit.initialize_output(
38-
::typeof(eigh_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
39-
)
40-
D = similar(A, complex(eltype(A)))
41-
V = similar(A, complex(eltype(A)))
42-
return (D, V)
23+
for f in [:default_eig_algorithm, :default_eigh_algorithm]
24+
@eval begin
25+
function MatrixAlgebraKit.$f(arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...)
26+
alg = $f(blocktype(arrayt); kwargs...)
27+
return BlockPermutedDiagonalAlgorithm(alg)
28+
end
29+
end
4330
end
4431

45-
function MatrixAlgebraKit.eigh_full!(
46-
A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm
47-
)
48-
for I in blockdiagindices(A)
49-
d, v = eigh_full!(A[I], alg.alg)
50-
D[I] = d
51-
V[I] = v
32+
for f in [:eig_full!, :eigh_full!]
33+
@eval begin
34+
function MatrixAlgebraKit.initialize_output(
35+
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
36+
)
37+
return initialize_blocksparse_eig_output($f, A, alg)
38+
end
39+
function MatrixAlgebraKit.$f(
40+
A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm
41+
)
42+
return blocksparse_eig_full!($f, A, (D, V), alg)
43+
end
5244
end
53-
return (D, V)
5445
end

0 commit comments

Comments
 (0)