Skip to content

Commit a552990

Browse files
authored
Generalize block sparse matricize (#163)
1 parent 229248e commit a552990

File tree

12 files changed

+108
-23
lines changed

12 files changed

+108
-23
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.8.1"
4+
version = "0.9.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -22,9 +22,11 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2222

2323
[weakdeps]
2424
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
25+
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
2526

2627
[extensions]
2728
BlockSparseArraysTensorAlgebraExt = "TensorAlgebra"
29+
BlockSparseArraysTensorProductsExt = "TensorProducts"
2830

2931
[compat]
3032
Adapt = "4.1.1"
@@ -43,6 +45,7 @@ MatrixAlgebraKit = "0.2.2"
4345
SparseArraysBase = "0.7.1"
4446
SplitApplyCombine = "1.2.3"
4547
TensorAlgebra = "0.3.2"
48+
TensorProducts = "0.1.7"
4649
Test = "1.10"
4750
TypeParameterAccessors = "0.4.1"
4851
julia = "1.10"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
66

77
[compat]
88
BlockArrays = "1"
9-
BlockSparseArrays = "0.8"
9+
BlockSparseArrays = "0.9"
1010
Documenter = "1"
1111
Literate = "2"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55

66
[compat]
77
BlockArrays = "1"
8-
BlockSparseArrays = "0.8"
8+
BlockSparseArrays = "0.9"
99
Test = "1"

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,51 @@ function TensorAlgebra.FusionStyle(::Type{<:AbstractBlockSparseArray})
1515
return BlockReshapeFusion()
1616
end
1717

18+
using BlockArrays: Block, blocklength, blocks
19+
using BlockSparseArrays: blocksparse
20+
using SparseArraysBase: eachstoredindex
21+
using TensorAlgebra: TensorAlgebra, matricize, unmatricize
1822
function TensorAlgebra.matricize(
1923
::BlockReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
2024
)
21-
new_axes = fuseaxes(axes(a), biperm)
22-
return blockreshape(a, new_axes)
25+
ax = fuseaxes(axes(a), biperm)
26+
reshaped_blocks_a = reshape(blocks(a), map(blocklength, ax))
27+
key(I) = Block(Tuple(I))
28+
value(I) = matricize(reshaped_blocks_a[I], biperm)
29+
Is = eachstoredindex(reshaped_blocks_a)
30+
bs = if isempty(Is)
31+
# Catch empty case and make sure the type is constrained properly.
32+
# This seems to only be necessary in Julia versions below v1.11,
33+
# try removing it when we drop support for those versions.
34+
keytype = Base.promote_op(key, eltype(Is))
35+
valtype = Base.promote_op(value, eltype(Is))
36+
valtype′ = !isconcretetype(valtype) ? AbstractMatrix{eltype(a)} : valtype
37+
Dict{keytype,valtype′}()
38+
else
39+
Dict(key(I) => value(I) for I in Is)
40+
end
41+
return blocksparse(bs, ax)
2342
end
2443

44+
using BlockArrays: blocklengths
2545
function TensorAlgebra.unmatricize(
2646
::BlockReshapeFusion,
2747
m::AbstractMatrix,
28-
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}},
48+
blocked_ax::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}},
2949
)
30-
return blockreshape(m, Tuple(blocked_axes)...)
50+
ax = Tuple(blocked_ax)
51+
reshaped_blocks_m = reshape(blocks(m), map(blocklength, ax))
52+
function f(I)
53+
block_axes_I = BlockedTuple(
54+
map(ntuple(identity, length(ax))) do i
55+
return Base.axes1(ax[i][Block(I[i])])
56+
end,
57+
blocklengths(blocked_ax),
58+
)
59+
return unmatricize(reshaped_blocks_m[I], block_axes_I)
60+
end
61+
bs = Dict(Block(Tuple(I)) => f(I) for I in eachstoredindex(reshaped_blocks_m))
62+
return blocksparse(bs, ax)
3163
end
3264

3365
end
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module BlockSparseArraysTensorProductsExt
2+
3+
using BlockSparseArrays: BlockUnitRange, blockrange, eachblockaxis
4+
using TensorProducts: TensorProducts, tensor_product
5+
# TODO: Dispatch on `FusionStyle` to allow different kinds of products,
6+
# for example to allow merging common symmetry sectors.
7+
function TensorProducts.tensor_product(a1::BlockUnitRange, a2::BlockUnitRange)
8+
new_blockaxes = vec(
9+
map(splat(tensor_product), Iterators.product(eachblockaxis(a1), eachblockaxis(a2)))
10+
)
11+
return blockrange(new_blockaxes)
12+
end
13+
14+
end

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,10 @@ end
180180
# ```
181181
# but includes `BlockIndices`, where the blocks aren't contiguous.
182182
const BlockSliceCollection = Union{
183-
Base.Slice,BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}}
183+
Base.Slice,
184+
BlockSlice{<:Block{1}},
185+
BlockSlice{<:BlockRange{1}},
186+
BlockIndices{<:Vector{<:Block{1}}},
184187
}
185188
const BlockIndexRangeSlice = BlockSlice{
186189
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
@@ -267,13 +270,15 @@ tuple_oneto(n) = ntuple(identity, n)
267270

268271
function _blockreshape(a::AbstractArray, axes::Tuple{Vararg{AbstractUnitRange}})
269272
reshaped_blocks_a = reshape(blocks(a), blocklength.(axes))
270-
reshaped_a = similar(a, axes)
271-
for I in eachstoredindex(reshaped_blocks_a)
272-
block_size_I = map(i -> length(axes[i][Block(I[i])]), tuple_oneto(length(axes)))
273+
function f(I)
274+
block_axes_I = map(ntuple(identity, length(axes))) do i
275+
return Base.axes1(axes[i][Block(I[i])])
276+
end
273277
# TODO: Better converter here.
274-
reshaped_a[Block(Tuple(I))] = reshape(reshaped_blocks_a[I], block_size_I)
278+
return reshape(reshaped_blocks_a[I], block_axes_I)
275279
end
276-
return reshaped_a
280+
bs = Dict(Block(Tuple(I)) => f(I) for I in eachstoredindex(reshaped_blocks_a))
281+
return blocksparse(bs, axes)
277282
end
278283

279284
function blockreshape(

src/BlockArraysExtensions/blockrange.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,21 @@ end
88
function blockrange(eachblockaxis)
99
return BlockUnitRange(blockedrange(length.(eachblockaxis)), eachblockaxis)
1010
end
11+
function blockrange(first::Integer, eachblockaxis)
12+
return BlockUnitRange(blockedrange(first, length.(eachblockaxis)), eachblockaxis)
13+
end
1114
Base.first(r::BlockUnitRange) = first(r.r)
1215
Base.last(r::BlockUnitRange) = last(r.r)
1316
BlockArrays.blocklasts(r::BlockUnitRange) = blocklasts(r.r)
1417
eachblockaxis(r::BlockUnitRange) = r.eachblockaxis
1518
function Base.getindex(r::BlockUnitRange, I::Block{1})
1619
return eachblockaxis(r)[Int(I)] .+ (first(r.r[I]) - 1)
1720
end
21+
function Base.getindex(r::BlockUnitRange, I::BlockRange{1})
22+
return blockrange(first(r), eachblockaxis(r)[Int.(I)])
23+
end
24+
Base.axes(r::BlockUnitRange) = (blockrange(eachblockaxis(r)),)
25+
Base.axes1(r::BlockUnitRange) = blockrange(eachblockaxis(r))
1826

1927
using BlockArrays: BlockedOneTo
2028
const BlockOneTo{T<:Integer,B,CS,R<:BlockedOneTo{T,CS}} = BlockUnitRange{T,B,CS,R}

src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,19 @@ function Base.setindex!(
8787
),
8888
)
8989
end
90-
# Custom `_convert` works around the issue that
91-
# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isnt' defined
92-
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
93-
# https://github.com/JuliaLang/julia/pull/52487).
94-
# TODO: Delete once we drop support for Julia v1.10.
95-
blocks(a)[Int.(I)...] = _convert(blocktype(a), value)
90+
if isstored(a, I...)
91+
# This writes into existing blocks, or constructs blocks
92+
# using the axes.
93+
aI = @view! a[I...]
94+
aI .= value
95+
else
96+
# Custom `_convert` works around the issue that
97+
# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isnt' defined
98+
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
99+
# https://github.com/JuliaLang/julia/pull/52487).
100+
# TODO: Delete `_convert` once we drop support for Julia v1.10.
101+
blocks(a)[Int.(I)...] = _convert(blocktype(a), value)
102+
end
96103
return a
97104
end
98105

src/blocksparsearray/blocksparsearray.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,13 +262,18 @@ function blocksparsezeros(::BlockType{A}, axes...) where {A<:AbstractArray}
262262
# to make a bit more generic.
263263
return BlockSparseArray{eltype(A),ndims(A),A}(undef, axes...)
264264
end
265-
function blocksparse(d::Dict{<:Block,<:AbstractArray}, axes...)
266-
a = blocksparsezeros(BlockType(valtype(d)), axes...)
265+
function blocksparse(d::Dict{<:Block,<:AbstractArray}, ax::Tuple)
266+
a = blocksparsezeros(BlockType(valtype(d)), ax...)
267267
for I in eachindex(d)
268268
a[I] = d[I]
269269
end
270270
return a
271271
end
272+
function blocksparse(
273+
d::Dict{<:Block,<:AbstractArray}, blocklens::AbstractVector{<:Integer}...
274+
)
275+
return blocksparse(d, map(blockedrange, blocklens))
276+
end
272277

273278
# Base `AbstractArray` interface
274279
Base.axes(a::BlockSparseArray) = a.axes

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,11 @@ function SparseArraysBase.getunstoredindex(
537537
return error("Not implemented.")
538538
end
539539

540+
# Convert a blockwise slice on a block sparse array
541+
# to an elementwise slice on the corresponding sparse array
542+
# of blocks.
540543
to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block)
544+
to_blocks_indices(I::BlockSlice{<:Block{1}}) = Int(I.block):Int(I.block)
541545
to_blocks_indices(I::BlockIndices{<:Vector{<:Block{1}}}) = Int.(I.blocks)
542546
to_blocks_indices(I::Base.Slice) = Base.OneTo(blocklength(I.indices))
543547

0 commit comments

Comments
 (0)