Skip to content

Commit 8cfd7f9

Browse files
authored
Blockdiagonal factorizations refactor (#166)
1 parent 31c97f4 commit 8cfd7f9

File tree

14 files changed

+574
-604
lines changed

14 files changed

+574
-604
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.9.4"
4+
version = "0.10.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
66

77
[compat]
88
BlockArrays = "1"
9-
BlockSparseArrays = "0.9"
9+
BlockSparseArrays = "0.10"
1010
Documenter = "1"
1111
Literate = "2"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55

66
[compat]
77
BlockArrays = "1"
8-
BlockSparseArrays = "0.9"
8+
BlockSparseArrays = "0.10"
99
Test = "1"

src/BlockSparseArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ include("blocksparsearray/blockdiagonalarray.jl")
4545
include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")
4646

4747
# factorizations
48-
include("factorizations/tensorproducts.jl")
48+
include("factorizations/utility.jl")
4949
include("factorizations/svd.jl")
5050
include("factorizations/truncation.jl")
5151
include("factorizations/qr.jl")

src/factorizations/eig.jl

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,82 @@
11
using BlockArrays: blocksizes
22
using DiagonalArrays: diagonal
33
using LinearAlgebra: LinearAlgebra, Diagonal
4-
using MatrixAlgebraKit:
5-
MatrixAlgebraKit,
6-
TruncationStrategy,
7-
check_input,
8-
default_eig_algorithm,
9-
default_eigh_algorithm,
10-
diagview,
11-
eig_full!,
12-
eig_trunc!,
13-
eig_vals!,
14-
eigh_full!,
15-
eigh_trunc!,
16-
eigh_vals!,
17-
findtruncated
4+
using MatrixAlgebraKit: MatrixAlgebraKit, diagview
5+
using MatrixAlgebraKit: default_eig_algorithm, eig_full!, eig_vals!
6+
using MatrixAlgebraKit: default_eigh_algorithm, eigh_full!, eigh_vals!
187

198
for f in [:default_eig_algorithm, :default_eigh_algorithm]
209
@eval begin
2110
function MatrixAlgebraKit.$f(::Type{<:AbstractBlockSparseMatrix}; kwargs...)
22-
return BlockPermutedDiagonalAlgorithm() do block
11+
return BlockDiagonalAlgorithm() do block
2312
return $f(block; kwargs...)
2413
end
2514
end
2615
end
2716
end
2817

18+
function output_type(::typeof(eig_full!), A::Type{<:AbstractMatrix{T}}) where {T}
19+
DV = Base.promote_op(eig_full!, A)
20+
return if isconcretetype(DV)
21+
DV
22+
else
23+
Tuple{AbstractMatrix{complex(T)},AbstractMatrix{complex(T)}}
24+
end
25+
end
26+
function output_type(::typeof(eigh_full!), A::Type{<:AbstractMatrix{T}}) where {T}
27+
DV = Base.promote_op(eigh_full!, A)
28+
return isconcretetype(DV) ? DV : Tuple{AbstractMatrix{real(T)},AbstractMatrix{T}}
29+
end
30+
2931
function MatrixAlgebraKit.check_input(
30-
::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V)
32+
::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V), ::BlockDiagonalAlgorithm
3133
)
3234
@assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix)
3335
@assert eltype(V) === eltype(D) === complex(eltype(A))
3436
@assert axes(A, 1) == axes(A, 2)
3537
@assert axes(A) == axes(D) == axes(V)
38+
@assert isblockdiagonal(A)
3639
return nothing
3740
end
3841
function MatrixAlgebraKit.check_input(
39-
::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V)
42+
::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V), ::BlockDiagonalAlgorithm
4043
)
4144
@assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix)
4245
@assert eltype(V) === eltype(A)
4346
@assert eltype(D) === real(eltype(A))
4447
@assert axes(A, 1) == axes(A, 2)
4548
@assert axes(A) == axes(D) == axes(V)
49+
@assert isblockdiagonal(A)
4650
return nothing
4751
end
4852

49-
function output_type(f::typeof(eig_full!), A::Type{<:AbstractMatrix{T}}) where {T}
50-
DV = Base.promote_op(f, A)
51-
!isconcretetype(DV) && return Tuple{AbstractMatrix{complex(T)},AbstractMatrix{complex(T)}}
52-
return DV
53-
end
54-
function output_type(f::typeof(eigh_full!), A::Type{<:AbstractMatrix{T}}) where {T}
55-
DV = Base.promote_op(f, A)
56-
!isconcretetype(DV) && return Tuple{AbstractMatrix{real(T)},AbstractMatrix{T}}
57-
return DV
58-
end
59-
6053
for f in [:eig_full!, :eigh_full!]
6154
@eval begin
6255
function MatrixAlgebraKit.initialize_output(
63-
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
56+
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
6457
)
6558
Td, Tv = fieldtypes(output_type($f, blocktype(A)))
6659
D = similar(A, BlockType(Td))
6760
V = similar(A, BlockType(Tv))
6861
return (D, V)
6962
end
7063
function MatrixAlgebraKit.$f(
71-
A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm
64+
A::AbstractBlockSparseMatrix, (D, V), alg::BlockDiagonalAlgorithm
7265
)
73-
check_input($f, A, (D, V))
74-
for I in eachstoredblockdiagindex(A)
75-
block = @view!(A[I])
76-
block_alg = block_algorithm(alg, block)
77-
D[I], V[I] = $f(block, block_alg)
78-
end
79-
for I in eachunstoredblockdiagindex(A)
80-
# TODO: Support setting `LinearAlgebra.I` directly, and/or
81-
# using `FillArrays.Eye`.
82-
V[I] = LinearAlgebra.I(size(@view(V[I]), 1))
66+
MatrixAlgebraKit.check_input($f, A, (D, V), alg)
67+
68+
# do decomposition on each block
69+
for bI in blockdiagindices(A)
70+
if isstored(A, bI)
71+
block = @view!(A[bI])
72+
block_alg = block_algorithm(alg, block)
73+
bD, bV = $f(block, block_alg)
74+
D[bI] = bD
75+
V[bI] = bV
76+
else
77+
# TODO: this should be `V[bI] = LinearAlgebra.I`
78+
copyto!(@view!(V[bI]), LinearAlgebra.I)
79+
end
8380
end
8481
return (D, V)
8582
end
@@ -100,17 +97,29 @@ end
10097
for f in [:eig_vals!, :eigh_vals!]
10198
@eval begin
10299
function MatrixAlgebraKit.initialize_output(
103-
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
100+
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
104101
)
105102
T = output_type($f, blocktype(A))
106103
return similar(A, BlockType(T), axes(A, 1))
107104
end
105+
function MatrixAlgebraKit.check_input(
106+
::typeof($f), A::AbstractBlockSparseMatrix, D, ::BlockDiagonalAlgorithm
107+
)
108+
@assert isa(D, AbstractBlockSparseVector)
109+
@assert eltype(D) === $(f == :eig_vals! ? complex : real)(eltype(A))
110+
@assert axes(A, 1) == axes(A, 2)
111+
@assert (axes(A, 1),) == axes(D)
112+
@assert isblockdiagonal(A)
113+
return nothing
114+
end
115+
108116
function MatrixAlgebraKit.$f(
109-
A::AbstractBlockSparseMatrix, D, alg::BlockPermutedDiagonalAlgorithm
117+
A::AbstractBlockSparseMatrix, D, alg::BlockDiagonalAlgorithm
110118
)
119+
MatrixAlgebraKit.check_input($f, A, D, alg)
111120
for I in eachblockstoredindex(A)
112121
block = @view!(A[I])
113-
D[I] = $f(block, block_algorithm(alg, block))
122+
D[Tuple(I)[1]] = $f(block, block_algorithm(alg, block))
114123
end
115124
return D
116125
end

0 commit comments

Comments
 (0)