Skip to content

Commit 9acb157

Browse files
committed
[WIP] Block sparse eig
1 parent 6bee699 commit 9acb157

File tree

7 files changed

+72
-60
lines changed

7 files changed

+72
-60
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.8"
4+
version = "0.6.9"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -39,7 +39,7 @@ GPUArraysCore = "0.1.0, 0.2"
3939
LinearAlgebra = "1.10"
4040
MacroTools = "0.5.13"
4141
MapBroadcast = "0.1.5"
42-
MatrixAlgebraKit = "0.2"
42+
MatrixAlgebraKit = "0.2.2"
4343
SparseArraysBase = "0.5"
4444
SplitApplyCombine = "1.2.3"
4545
TensorAlgebra = "0.3.2"

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,6 @@ include("factorizations/qr.jl")
5151
include("factorizations/lq.jl")
5252
include("factorizations/polar.jl")
5353
include("factorizations/orthnull.jl")
54+
include("factorizations/eig.jl")
5455

5556
end

src/factorizations/eig.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using MatrixAlgebraKit:
2+
MatrixAlgebraKit, default_eig_algorithm, default_eigh_algorithm, eig_full!, eigh_full!
3+
4+
function MatrixAlgebraKit.default_eig_algorithm(
5+
arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...
6+
)
7+
alg = default_eig_algorithm(blocktype(arrayt); kwargs...)
8+
return BlockPermutedDiagonalAlgorithm(alg)
9+
end
10+
11+
function MatrixAlgebraKit.initialize_output(
12+
::typeof(eig_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
13+
)
14+
D = similar(A, complex(eltype(A)))
15+
V = similar(A, complex(eltype(A)))
16+
return (D, V)
17+
end
18+
19+
function MatrixAlgebraKit.eig_full!(
20+
A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm
21+
)
22+
for I in blockdiagindices(A)
23+
d, v = eig_full!(A[I], alg.alg)
24+
D[I] = d
25+
V[I] = v
26+
end
27+
return (D, V)
28+
end
29+
30+
# TODO: this is a hardcoded for now to get around this function not being defined in the
31+
# type domain
32+
function MatrixAlgebraKit.default_eigh_algorithm(
33+
arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...
34+
)
35+
alg = default_eigh_algorithm(blocktype(arrayt); kwargs...)
36+
return BlockPermutedDiagonalAlgorithm(alg)
37+
end
38+
39+
function MatrixAlgebraKit.initialize_output(
40+
::typeof(eigh_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
41+
)
42+
D = similar(A, complex(eltype(A)))
43+
V = similar(A, complex(eltype(A)))
44+
return (D, V)
45+
end
46+
47+
function MatrixAlgebraKit.eigh_full!(
48+
A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm
49+
)
50+
for I in blockdiagindices(A)
51+
d, v = eigh_full!(A[I], alg.alg)
52+
D[I] = d
53+
V[I] = v
54+
end
55+
return (D, V)
56+
end

src/factorizations/lq.jl

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,12 @@
1-
using MatrixAlgebraKit: MatrixAlgebraKit, lq_compact!, lq_full!
1+
using MatrixAlgebraKit: MatrixAlgebraKit, default_lq_algorithm, lq_compact!, lq_full!
22

33
# TODO: this is a hardcoded for now to get around this function not being defined in the
44
# type domain
5-
function default_blocksparse_lq_algorithm(A::AbstractMatrix; kwargs...)
6-
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
7-
error("unsupported type: $(blocktype(A))")
8-
alg = MatrixAlgebraKit.LAPACK_HouseholderLQ(; kwargs...)
9-
return BlockPermutedDiagonalAlgorithm(alg)
10-
end
11-
function MatrixAlgebraKit.default_algorithm(
12-
::typeof(lq_compact!), A::AbstractBlockSparseMatrix; kwargs...
5+
function MatrixAlgebraKit.default_lq_algorithm(
6+
arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...
137
)
14-
return default_blocksparse_lq_algorithm(A; kwargs...)
15-
end
16-
function MatrixAlgebraKit.default_algorithm(
17-
::typeof(lq_full!), A::AbstractBlockSparseMatrix; kwargs...
18-
)
19-
return default_blocksparse_lq_algorithm(A; kwargs...)
8+
alg = default_lq_algorithm(blocktype(arrayt); kwargs...)
9+
return BlockPermutedDiagonalAlgorithm(alg)
2010
end
2111

2212
function similar_output(

src/factorizations/polar.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,3 @@ function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, alg::PolarV
4646
P = U * S * copy(U')
4747
return (P, Wᴴ)
4848
end
49-
50-
function MatrixAlgebraKit.default_algorithm(
51-
::typeof(left_polar!), a::AbstractBlockSparseMatrix; kwargs...
52-
)
53-
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
54-
end
55-
function MatrixAlgebraKit.default_algorithm(
56-
::typeof(right_polar!), a::AbstractBlockSparseMatrix; kwargs...
57-
)
58-
return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...))
59-
end

src/factorizations/qr.jl

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,11 @@ using MatrixAlgebraKit:
33

44
# TODO: this is a hardcoded for now to get around this function not being defined in the
55
# type domain
6-
function MatrixAlgebraKit.default_qr_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
7-
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
8-
error("unsupported type: $(blocktype(A))")
9-
alg = MatrixAlgebraKit.LAPACK_HouseholderQR(; kwargs...)
10-
return BlockPermutedDiagonalAlgorithm(alg)
11-
end
12-
function MatrixAlgebraKit.default_algorithm(
13-
::typeof(qr_compact!), A::AbstractBlockSparseMatrix; kwargs...
6+
function MatrixAlgebraKit.default_qr_algorithm(
7+
arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...
148
)
15-
return default_qr_algorithm(A; kwargs...)
16-
end
17-
function MatrixAlgebraKit.default_algorithm(
18-
::typeof(qr_full!), A::AbstractBlockSparseMatrix; kwargs...
19-
)
20-
return default_qr_algorithm(A; kwargs...)
9+
alg = default_qr_algorithm(blocktype(arrayt); kwargs...)
10+
return BlockPermutedDiagonalAlgorithm(alg)
2111
end
2212

2313
function similar_output(

src/factorizations/svd.jl

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,11 @@ struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <:
1212
alg::A
1313
end
1414

15-
function MatrixAlgebraKit.default_svd_algorithm(A; kwargs...)
16-
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
17-
error("unsupported type: $(blocktype(A))")
18-
# TODO: this is a hardcoded for now to get around this function not being defined in the
19-
# type domain
20-
# alg = MatrixAlgebraKit.default_algorithm(f, blocktype(A); kwargs...)
21-
alg = MatrixAlgebraKit.LAPACK_DivideAndConquer(; kwargs...)
22-
return BlockPermutedDiagonalAlgorithm(alg)
23-
end
24-
25-
function MatrixAlgebraKit.default_algorithm(
26-
f::typeof(svd_compact!), A::AbstractBlockSparseMatrix; kwargs...
15+
function MatrixAlgebraKit.default_svd_algorithm(
16+
arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...
2717
)
28-
return default_svd_algorithm(A; kwargs...)
29-
end
30-
function MatrixAlgebraKit.default_algorithm(
31-
f::typeof(svd_full!), A::AbstractBlockSparseMatrix; kwargs...
32-
)
33-
return default_svd_algorithm(A; kwargs...)
18+
alg = default_svd_algorithm(blocktype(arrayt))
19+
return BlockPermutedDiagonalAlgorithm(alg)
3420
end
3521

3622
function similar_output(

0 commit comments

Comments
 (0)