Skip to content

Commit 78188cc

Browse files
authored
Block sparse eig (#129)
1 parent ea250bb commit 78188cc

File tree

9 files changed

+279
-38
lines changed

9 files changed

+279
-38
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.1"
4+
version = "0.7.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,6 @@ include("factorizations/qr.jl")
5151
include("factorizations/lq.jl")
5252
include("factorizations/polar.jl")
5353
include("factorizations/orthnull.jl")
54+
include("factorizations/eig.jl")
5455

5556
end

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,15 @@ function Base.similar(
348348
return @interface BlockSparseArrayInterface() similar(a, elt, axes)
349349
end
350350

351+
struct BlockType{T} end
352+
BlockType(x) = BlockType{x}()
353+
function Base.similar(a::AbstractBlockSparseArray, ::BlockType{T}, ax) where {T}
354+
return BlockSparseArray{eltype(T),ndims(T),T}(undef, ax)
355+
end
356+
function Base.similar(a::AbstractBlockSparseArray, T::BlockType)
357+
return similar(a, T, axes(a))
358+
end
359+
351360
# TODO: Implement this in a more generic way using a smarter `copyto!`,
352361
# which is ultimately what `Array{T,N}(::AbstractArray{<:Any,N})` calls.
353362
# These are defined for now to avoid scalar indexing issues when there

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ end
4545
function eachstoredblockdiagindex(a::AbstractArray)
4646
return eachblockstoredindex(a) blockdiagindices(a)
4747
end
48+
function eachunstoredblockdiagindex(a::AbstractArray)
49+
return setdiff(blockdiagindices(a), eachblockstoredindex(a))
50+
end
4851

4952
# Like `BlockArrays.eachblock` but only iterating
5053
# over stored blocks.

src/factorizations/eig.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
using BlockArrays: blocksizes
2+
using DiagonalArrays: diagonal
3+
using LinearAlgebra: LinearAlgebra, Diagonal
4+
using MatrixAlgebraKit:
5+
MatrixAlgebraKit,
6+
TruncationStrategy,
7+
check_input,
8+
default_eig_algorithm,
9+
default_eigh_algorithm,
10+
diagview,
11+
eig_full!,
12+
eig_trunc!,
13+
eig_vals!,
14+
eigh_full!,
15+
eigh_trunc!,
16+
eigh_vals!,
17+
findtruncated
18+
19+
for f in [:default_eig_algorithm, :default_eigh_algorithm]
20+
@eval begin
21+
function MatrixAlgebraKit.$f(arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...)
22+
alg = $f(blocktype(arrayt); kwargs...)
23+
return BlockPermutedDiagonalAlgorithm(alg)
24+
end
25+
end
26+
end
27+
28+
function MatrixAlgebraKit.check_input(
29+
::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V)
30+
)
31+
@assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix)
32+
@assert eltype(V) === eltype(D) === complex(eltype(A))
33+
@assert axes(A, 1) == axes(A, 2)
34+
@assert axes(A) == axes(D) == axes(V)
35+
return nothing
36+
end
37+
function MatrixAlgebraKit.check_input(
38+
::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V)
39+
)
40+
@assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix)
41+
@assert eltype(V) === eltype(A)
42+
@assert eltype(D) === real(eltype(A))
43+
@assert axes(A, 1) == axes(A, 2)
44+
@assert axes(A) == axes(D) == axes(V)
45+
return nothing
46+
end
47+
48+
for f in [:eig_full!, :eigh_full!]
49+
@eval begin
50+
function MatrixAlgebraKit.initialize_output(
51+
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
52+
)
53+
Td, Tv = fieldtypes(Base.promote_op($f, blocktype(A), typeof(alg.alg)))
54+
D = similar(A, BlockType(Td))
55+
V = similar(A, BlockType(Tv))
56+
return (D, V)
57+
end
58+
function MatrixAlgebraKit.$f(
59+
A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm
60+
)
61+
check_input($f, A, (D, V))
62+
for I in eachstoredblockdiagindex(A)
63+
D[I], V[I] = $f(@view(A[I]), alg.alg)
64+
end
65+
for I in eachunstoredblockdiagindex(A)
66+
# TODO: Support setting `LinearAlgebra.I` directly, and/or
67+
# using `FillArrays.Eye`.
68+
V[I] = LinearAlgebra.I(size(@view(V[I]), 1))
69+
end
70+
return (D, V)
71+
end
72+
end
73+
end
74+
75+
for f in [:eig_vals!, :eigh_vals!]
76+
@eval begin
77+
function MatrixAlgebraKit.initialize_output(
78+
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
79+
)
80+
T = Base.promote_op($f, blocktype(A), typeof(alg.alg))
81+
return similar(A, BlockType(T), axes(A, 1))
82+
end
83+
function MatrixAlgebraKit.$f(
84+
A::AbstractBlockSparseMatrix, D, alg::BlockPermutedDiagonalAlgorithm
85+
)
86+
for I in eachblockstoredindex(A)
87+
D[I] = $f(@view!(A[I]), alg.alg)
88+
end
89+
return D
90+
end
91+
end
92+
end

src/factorizations/svd.jl

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using MatrixAlgebraKit: MatrixAlgebraKit, default_svd_algorithm, svd_compact!, svd_full!
1+
using MatrixAlgebraKit:
2+
MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!
23

34
"""
45
BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm)
@@ -152,45 +153,40 @@ function MatrixAlgebraKit.initialize_output(
152153
end
153154

154155
function MatrixAlgebraKit.check_input(
155-
::typeof(svd_compact!), A::AbstractBlockSparseMatrix, USVᴴ
156+
::typeof(svd_compact!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ)
156157
)
157-
U, S, Vt = USVᴴ
158158
@assert isa(U, AbstractBlockSparseMatrix) &&
159159
isa(S, AbstractBlockSparseMatrix) &&
160-
isa(Vt, AbstractBlockSparseMatrix)
161-
@assert eltype(A) == eltype(U) == eltype(Vt)
160+
isa(Vᴴ, AbstractBlockSparseMatrix)
161+
@assert eltype(A) == eltype(U) == eltype(Vᴴ)
162162
@assert real(eltype(A)) == eltype(S)
163-
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vt, 2)
163+
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 2)
164164
@assert axes(S, 1) == axes(S, 2)
165-
166165
return nothing
167166
end
168167

169168
function MatrixAlgebraKit.check_input(
170-
::typeof(svd_full!), A::AbstractBlockSparseMatrix, USVᴴ
169+
::typeof(svd_full!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ)
171170
)
172-
U, S, Vt = USVᴴ
173171
@assert isa(U, AbstractBlockSparseMatrix) &&
174172
isa(S, AbstractBlockSparseMatrix) &&
175-
isa(Vt, AbstractBlockSparseMatrix)
176-
@assert eltype(A) == eltype(U) == eltype(Vt)
173+
isa(Vᴴ, AbstractBlockSparseMatrix)
174+
@assert eltype(A) == eltype(U) == eltype(Vᴴ)
177175
@assert real(eltype(A)) == eltype(S)
178-
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vt, 1) == axes(Vt, 2)
176+
@assert axes(A, 1) == axes(U, 1) && axes(A, 2) == axes(Vᴴ, 1) == axes(Vᴴ, 2)
179177
@assert axes(S, 2) == axes(A, 2)
180-
181178
return nothing
182179
end
183180

184181
function MatrixAlgebraKit.svd_compact!(
185-
A::AbstractBlockSparseMatrix, USVᴴ, alg::BlockPermutedDiagonalAlgorithm
182+
A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockPermutedDiagonalAlgorithm
186183
)
187-
MatrixAlgebraKit.check_input(svd_compact!, A, USVᴴ)
188-
U, S, Vt = USVᴴ
184+
check_input(svd_compact!, A, (U, S, Vᴴ))
189185

190186
# do decomposition on each block
191187
for bI in eachblockstoredindex(A)
192188
brow, bcol = Tuple(bI)
193-
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vt[bcol, bcol]))
189+
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol]))
194190
usvᴴ′ = svd_compact!(@view!(A[bI]), usvᴴ, alg.alg)
195191
@assert usvᴴ === usvᴴ′ "svd_compact! might not be in-place"
196192
end
@@ -203,25 +199,24 @@ function MatrixAlgebraKit.svd_compact!(
203199
emptycols = setdiff(1:blocksize(A, 2), bcolIs)
204200
# needs copyto! instead because size(::LinearAlgebra.I) doesn't work
205201
# U[Block(row, col)] = LinearAlgebra.I
206-
# Vt[Block(col, col)] = LinearAlgebra.I
202+
# Vᴴ[Block(col, col)] = LinearAlgebra.I
207203
for (row, col) in zip(emptyrows, emptycols)
208204
copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I)
209-
copyto!(@view!(Vt[Block(col, col)]), LinearAlgebra.I)
205+
copyto!(@view!(Vᴴ[Block(col, col)]), LinearAlgebra.I)
210206
end
211207

212-
return USVᴴ
208+
return (U, S, Vᴴ)
213209
end
214210

215211
function MatrixAlgebraKit.svd_full!(
216-
A::AbstractBlockSparseMatrix, USVᴴ, alg::BlockPermutedDiagonalAlgorithm
212+
A::AbstractBlockSparseMatrix, (U, S, Vᴴ), alg::BlockPermutedDiagonalAlgorithm
217213
)
218-
MatrixAlgebraKit.check_input(svd_full!, A, USVᴴ)
219-
U, S, Vt = USVᴴ
214+
check_input(svd_full!, A, (U, S, Vᴴ))
220215

221216
# do decomposition on each block
222217
for bI in eachblockstoredindex(A)
223218
brow, bcol = Tuple(bI)
224-
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vt[bcol, bcol]))
219+
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol]))
225220
usvᴴ′ = svd_full!(@view!(A[bI]), usvᴴ, alg.alg)
226221
@assert usvᴴ === usvᴴ′ "svd_full! might not be in-place"
227222
end
@@ -237,17 +232,17 @@ function MatrixAlgebraKit.svd_full!(
237232
# Vt[Block(col, col)] = LinearAlgebra.I
238233
for (row, col) in zip(emptyrows, emptycols)
239234
copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I)
240-
copyto!(@view!(Vt[Block(col, col)]), LinearAlgebra.I)
235+
copyto!(@view!(Vᴴ[Block(col, col)]), LinearAlgebra.I)
241236
end
242237

243238
# also handle extra rows/cols
244239
for i in (length(emptyrows) + 1):length(emptycols)
245-
copyto!(@view!(Vt[Block(emptycols[i], emptycols[i])]), LinearAlgebra.I)
240+
copyto!(@view!(Vᴴ[Block(emptycols[i], emptycols[i])]), LinearAlgebra.I)
246241
end
247242
bn = blocksize(A, 2)
248243
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
249244
copyto!(@view!(U[Block(emptyrows[k], bn + i)]), LinearAlgebra.I)
250245
end
251246

252-
return USVᴴ
247+
return (U, S, Vᴴ)
253248
end

src/factorizations/truncation.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using MatrixAlgebraKit: TruncationStrategy, diagview, svd_trunc!
1+
using MatrixAlgebraKit: TruncationStrategy, diagview, eig_trunc!, eigh_trunc!, svd_trunc!
22

33
function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T}
44
D = BlockSparseVector{T}(undef, axes(A, 1))
@@ -21,18 +21,29 @@ struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: Truncat
2121
strategy::T
2222
end
2323

24-
const TBlockUSVᴴ = Tuple{
25-
<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix
26-
}
27-
2824
function MatrixAlgebraKit.truncate!(
29-
::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy
25+
::typeof(svd_trunc!),
26+
(U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix},
27+
strategy::TruncationStrategy,
3028
)
3129
# TODO assert blockdiagonal
3230
return MatrixAlgebraKit.truncate!(
3331
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy)
3432
)
3533
end
34+
for f in [:eig_trunc!, :eigh_trunc!]
35+
@eval begin
36+
function MatrixAlgebraKit.truncate!(
37+
::typeof($f),
38+
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
39+
strategy::TruncationStrategy,
40+
)
41+
return MatrixAlgebraKit.truncate!(
42+
$f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy)
43+
)
44+
end
45+
end
46+
end
3647

3748
# cannot use regular slicing here: I want to slice without altering blockstructure
3849
# solution: use boolean indexing and slice the mask, effectively cheaply inverting the map
@@ -47,9 +58,21 @@ end
4758

4859
function MatrixAlgebraKit.truncate!(
4960
::typeof(svd_trunc!),
50-
(U, S, Vᴴ)::TBlockUSV,
61+
(U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix},
5162
strategy::BlockPermutedDiagonalTruncationStrategy,
5263
)
5364
I = MatrixAlgebraKit.findtruncated(diagview(S), strategy)
5465
return (U[:, I], S[I, I], Vᴴ[I, :])
5566
end
67+
for f in [:eig_trunc!, :eigh_trunc!]
68+
@eval begin
69+
function MatrixAlgebraKit.truncate!(
70+
::typeof($f),
71+
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
72+
strategy::BlockPermutedDiagonalTruncationStrategy,
73+
)
74+
I = MatrixAlgebraKit.findtruncated(diagview(D), strategy)
75+
return (D[I, I], V[:, I])
76+
end
77+
end
78+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1414
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
15+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1516
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1617
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1718
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 commit comments

Comments
 (0)