Skip to content

Commit b092ce7

Browse files
authored
[BlockSparseArrays] Permute and merge blocks (#1514)
* [BlockSparseArrays] Permute and merge blocks * [NDTensors] Bump to v0.3.39
1 parent 57f3321 commit b092ce7

File tree

10 files changed

+552
-84
lines changed

10 files changed

+552
-84
lines changed

ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,12 @@ function TensorAlgebra.splitdims(
6666
return length(axis) length(axes(a, i))
6767
end
6868
blockperms = invblockperm.(blocksortperm.(axes_prod))
69-
a_blockpermed = a[blockperms...]
69+
# TODO: This is doing extra copies of the blocks,
70+
# use `@view a[axes_prod...]` instead.
71+
# That will require implementing some reindexing logic
72+
# for this combination of slicing.
73+
a_unblocked = a[axes_prod...]
74+
a_blockpermed = a_unblocked[blockperms...]
7075
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
7176
end
7277

ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
8787
a = BlockSparseArray{elt}(d1, d2, d1, d2)
8888
blockdiagonal!(randn!, a)
8989
m = fusedims(a, (1, 2), (3, 4))
90-
# TODO: Once block merging is implemented, this should
91-
# be the real test.
9290
for ax in axes(m)
9391
@test ax isa GradedOneTo
94-
# TODO: Current `fusedims` doesn't merge
95-
# common sectors, need to fix.
96-
@test_broken blocklabels(ax) == [U1(0), U1(1), U1(2)]
97-
@test blocklabels(ax) == [U1(0), U1(1), U1(1), U1(2)]
92+
@test blocklabels(ax) == [U1(0), U1(1), U1(2)]
9893
end
9994
for I in CartesianIndices(m)
10095
if I CartesianIndex.([(1, 1), (4, 4)])
@@ -105,10 +100,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
105100
end
106101
@test a[1, 1, 1, 1] == m[1, 1]
107102
@test a[2, 2, 2, 2] == m[4, 4]
108-
# TODO: Current `fusedims` doesn't merge
109-
# common sectors, need to fix.
110-
@test_broken blocksize(m) == (3, 3)
111-
@test blocksize(m) == (4, 4)
103+
@test blocksize(m) == (3, 3)
112104
@test a == splitdims(m, (d1, d2), (d1, d2))
113105
end
114106
@testset "dual axes" begin

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 124 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@ using BlockArrays:
33
AbstractBlockArray,
44
AbstractBlockVector,
55
Block,
6+
BlockIndex,
7+
BlockIndexRange,
68
BlockRange,
9+
BlockSlice,
10+
BlockVector,
711
BlockedOneTo,
812
BlockedUnitRange,
9-
BlockVector,
10-
BlockSlice,
13+
BlockedVector,
1114
block,
1215
blockaxes,
1316
blockedrange,
@@ -17,8 +20,30 @@ using BlockArrays:
1720
findblockindex
1821
using Compat: allequal
1922
using Dictionaries: Dictionary, Indices
20-
using ..GradedAxes: blockedunitrange_getindices
21-
using ..SparseArrayInterface: stored_indices
23+
using ..GradedAxes: blockedunitrange_getindices, to_blockindices
24+
using ..SparseArrayInterface: SparseArrayInterface, nstored, stored_indices
25+
26+
# A return type for `blocks(array)` when `array` isn't blocked.
27+
# Represents a vector with just that single block.
28+
struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
29+
array::Array
30+
end
31+
blocks_maybe_single(a) = blocks(a)
32+
blocks_maybe_single(a::Array) = SingleBlockView(a)
33+
function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
34+
@assert all(isone, index)
35+
return a.array
36+
end
37+
38+
# A wrapper around a potentially blocked array that is not blocked.
39+
struct NonBlockedArray{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N}
40+
array::Array
41+
end
42+
Base.size(a::NonBlockedArray) = size(a.array)
43+
Base.getindex(a::NonBlockedArray{<:Any,N}, I::Vararg{Integer,N}) where {N} = a.array[I...]
44+
BlockArrays.blocks(a::NonBlockedArray) = SingleBlockView(a.array)
45+
const NonBlockedVector{T,Array} = NonBlockedArray{T,1,Array}
46+
NonBlockedVector(array::AbstractVector) = NonBlockedArray(array)
2247

2348
# BlockIndices works around an issue that the indices of BlockSlice
2449
# are restricted to AbstractUnitRange{Int}.
@@ -37,6 +62,43 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
3762
@assert length(S.indices[Block(i)]) == length(i.indices)
3863
return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)])
3964
end
65+
66+
# This is used in slicing like:
67+
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
68+
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
69+
# a[I, I]
70+
function Base.getindex(
71+
S::BlockIndices{<:AbstractBlockVector{<:Block{1}}}, i::BlockSlice{<:Block{1}}
72+
)
73+
# TODO: Check for conistency of indices.
74+
# Wrapping the indices in `NonBlockedVector` reinterprets the blocked indices
75+
# as a single block, since the result shouldn't be blocked.
76+
return NonBlockedVector(BlockIndices(S.blocks[Block(i)], S.indices[Block(i)]))
77+
end
78+
function Base.getindex(
79+
S::BlockIndices{<:BlockedVector{<:Block{1},<:BlockRange{1}}}, i::BlockSlice{<:Block{1}}
80+
)
81+
return i
82+
end
83+
84+
# Used in indexing such as:
85+
# ```julia
86+
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
87+
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
88+
# b = @view a[I, I]
89+
# @view b[Block(1, 1)[1:2, 2:2]]
90+
# ```
91+
# This is similar to the definition:
92+
# blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}})
93+
function Base.getindex(
94+
a::NonBlockedVector{<:Integer,<:BlockIndices}, I::UnitRange{<:Integer}
95+
)
96+
ax = only(axes(a.array.indices))
97+
brs = to_blockindices(ax, I)
98+
inds = blockedunitrange_getindices(ax, I)
99+
return NonBlockedVector(a.array[BlockSlice(brs, inds)])
100+
end
101+
40102
function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
41103
# TODO: Check that `i.indices` is consistent with `S.indices`.
42104
# TODO: Turn this into a `blockedunitrange_getindices` definition.
@@ -50,6 +112,34 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
50112
return BlockIndices(subblocks, subindices)
51113
end
52114

115+
# Used when performing slices like:
116+
# @views a[[Block(2), Block(1)]][2:4, 2:4]
117+
function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockVector{<:BlockIndex{1}}})
118+
subblocks = mortar(
119+
map(blocks(i.block)) do br
120+
return S.blocks[Int(Block(br))][only(br.indices)]
121+
end,
122+
)
123+
subindices = mortar(
124+
map(blocks(i.block)) do br
125+
S.indices[br]
126+
end,
127+
)
128+
return BlockIndices(subblocks, subindices)
129+
end
130+
131+
# Similar to the definition of `BlockArrays.BlockSlices`:
132+
# ```julia
133+
# const BlockSlices = Union{Base.Slice,BlockSlice{<:BlockRange{1}}}
134+
# ```
135+
# but includes `BlockIndices`, where the blocks aren't contiguous.
136+
const BlockSliceCollection = Union{
137+
Base.Slice,BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}}
138+
}
139+
const SubBlockSliceCollection = BlockIndices{
140+
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
141+
}
142+
53143
# TODO: This is type piracy. This is used in `reindex` when making
54144
# views of blocks of sliced block arrays, for example:
55145
# ```julia
@@ -218,6 +308,12 @@ function blockrange(axis::AbstractUnitRange, r::UnitRange)
218308
return findblock(axis, first(r)):findblock(axis, last(r))
219309
end
220310

311+
# Occurs when slicing with `a[2:4, 2:4]`.
312+
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedUnitRange{<:Integer})
313+
# TODO: Check the blocks are commensurate.
314+
return findblock(axis, first(r)):findblock(axis, last(r))
315+
end
316+
221317
function blockrange(axis::AbstractUnitRange, r::Int)
222318
## return findblock(axis, r)
223319
return error("Slicing with integer values isn't supported.")
@@ -241,14 +337,17 @@ function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer})
241337
return only(blockaxes(r))
242338
end
243339

244-
# This handles changing the blocking, for example:
340+
# This handles block merging:
245341
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
342+
# I = BlockedVector(Block.(1:4), [2, 2])
343+
# I = BlockVector(Block.(1:4), [2, 2])
246344
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
345+
# I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
247346
# a[I, I]
248-
# TODO: Generalize to `AbstractBlockedUnitRange` and `AbstractBlockVector`.
249-
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer})
250-
# TODO: Probably this is incorrect and should be something like:
251-
# return findblock(axis, first(r)):findblock(axis, last(r))
347+
function blockrange(axis::BlockedOneTo{<:Integer}, r::AbstractBlockVector{<:Block{1}})
348+
for b in r
349+
@assert b blockaxes(axis, 1)
350+
end
252351
return only(blockaxes(r))
253352
end
254353

@@ -287,6 +386,10 @@ function blockrange(axis::AbstractUnitRange, r::Base.Slice)
287386
return only(blockaxes(axis))
288387
end
289388

389+
function blockrange(axis::AbstractUnitRange, r::NonBlockedVector)
390+
return Block(1):Block(1)
391+
end
392+
290393
function blockrange(axis::AbstractUnitRange, r)
291394
return error("Slicing not implemented for range of type `$(typeof(r))`.")
292395
end
@@ -423,7 +526,18 @@ function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) wher
423526
return a
424527
end
425528

426-
function view!(a::BlockSparseArray{<:Any,N}, index::Block{N}) where {N}
529+
function SparseArrayInterface.nstored(a::BlockView)
530+
# TODO: Store whether or not the block is stored already as
531+
# a Bool in `BlockView`.
532+
I = CartesianIndex(Int.(a.block))
533+
# TODO: Use `block_stored_indices`.
534+
if I stored_indices(blocks(a.array))
535+
return nstored(blocks(a.array)[I])
536+
end
537+
return 0
538+
end
539+
540+
function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N}
427541
return view!(a, Tuple(index)...)
428542
end
429543
function view!(a::AbstractArray{<:Any,N}, index::Vararg{Block{1},N}) where {N}

src/BlockSparseArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
module BlockSparseArrays
2+
include("BlockArraysExtensions/BlockArraysExtensions.jl")
23
include("blocksparsearrayinterface/blocksparsearrayinterface.jl")
34
include("blocksparsearrayinterface/linearalgebra.jl")
45
include("blocksparsearrayinterface/blockzero.jl")
56
include("blocksparsearrayinterface/broadcast.jl")
67
include("blocksparsearrayinterface/arraylayouts.jl")
8+
include("blocksparsearrayinterface/views.jl")
79
include("abstractblocksparsearray/abstractblocksparsearray.jl")
810
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
911
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
@@ -15,7 +17,6 @@ include("abstractblocksparsearray/broadcast.jl")
1517
include("abstractblocksparsearray/map.jl")
1618
include("blocksparsearray/defaults.jl")
1719
include("blocksparsearray/blocksparsearray.jl")
18-
include("BlockArraysExtensions/BlockArraysExtensions.jl")
1920
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
2021
include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl")
2122
include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl")

src/abstractblocksparsearray/map.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,57 @@ end
2525
# This is type piracy, try to avoid this, maybe requires defining `map`.
2626
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)
2727

28+
reblock(a) = a
29+
30+
# If the blocking of the slice doesn't match the blocking of the
31+
# parent array, reblock according to the blocking of the parent array.
32+
function reblock(
33+
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}}
34+
)
35+
# TODO: This relies on the behavior that slicing a block sparse
36+
# array with a UnitRange inherits the blocking of the underlying
37+
# block sparse array, we might change that default behavior
38+
# so this might become something like `@blocked parent(a)[...]`.
39+
return @view parent(a)[UnitRange{Int}.(parentindices(a))...]
40+
end
41+
42+
function reblock(
43+
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}}
44+
)
45+
return @view parent(a)[map(I -> I.array, parentindices(a))...]
46+
end
47+
48+
function reblock(
49+
a::SubArray{
50+
<:Any,
51+
<:Any,
52+
<:AbstractBlockSparseArray,
53+
<:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}},
54+
},
55+
)
56+
# Remove the blocking.
57+
return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...]
58+
end
59+
60+
# TODO: Rewrite this so that it takes the blocking structure
61+
# made by combining the blocking of the axes (i.e. the blocking that
62+
# is used to determine `union_stored_blocked_cartesianindices(...)`).
63+
# `reblock` is a partial solution to that, but a bit ad-hoc.
64+
# TODO: Move to `blocksparsearrayinterface/map.jl`.
2865
function SparseArrayInterface.sparse_map!(
2966
::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}
3067
)
68+
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
3169
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
3270
BI_dest = blockindexrange(a_dest, I)
3371
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
3472
# TODO: Investigate why this doesn't work:
3573
# block_dest = @view a_dest[_block(BI_dest)]
36-
block_dest = blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...]
74+
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...]
3775
# TODO: Investigate why this doesn't work:
3876
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
3977
block_srcs = ntuple(length(a_srcs)) do i
40-
return blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
78+
return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
4179
end
4280
subblock_dest = @view block_dest[BI_dest.indices...]
4381
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))

0 commit comments

Comments
 (0)