Skip to content

Commit 243e917

Browse files
authored
Construct BlockSparseArray when slicing with graded unit ranges (#36)
1 parent 37fecf7 commit 243e917

File tree

9 files changed

+277
-92
lines changed

9 files changed

+277
-92
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
2525
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
2626

2727
[extensions]
28+
BlockSparseArraysGradedUnitRangesExt = "GradedUnitRanges"
2829
BlockSparseArraysTensorAlgebraExt = ["LabelledNumbers", "TensorAlgebra"]
2930

3031
[compat]
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
module BlockSparseArraysGradedUnitRangesExt
2+
3+
using BlockSparseArrays: AnyAbstractBlockSparseArray, BlockSparseArray, blocktype
4+
using GradedUnitRanges: AbstractGradedUnitRange
5+
using TypeParameterAccessors: set_ndims, unwrap_array_type
6+
7+
# A block spare array similar to the input (dense) array.
8+
# TODO: Make `BlockSparseArrays.blocksparse_similar` more general and use that,
9+
# and also turn it into an DerivableInterfaces.jl-based interface function.
10+
function similar_blocksparse(
11+
a::AbstractArray,
12+
elt::Type,
13+
axes::Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}},
14+
)
15+
# TODO: Probably need to unwrap the type of `a` in certain cases
16+
# to make a proper block type.
17+
return BlockSparseArray{
18+
elt,length(axes),set_ndims(unwrap_array_type(blocktype(a)), length(axes))
19+
}(
20+
axes
21+
)
22+
end
23+
24+
function Base.similar(
25+
a::AbstractArray,
26+
elt::Type,
27+
axes::Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}},
28+
)
29+
return similar_blocksparse(a, elt, axes)
30+
end
31+
32+
# Fix ambiguity error with `BlockArrays.jl`.
33+
function Base.similar(
34+
a::StridedArray,
35+
elt::Type,
36+
axes::Tuple{
37+
AbstractGradedUnitRange,AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}
38+
},
39+
)
40+
return similar_blocksparse(a, elt, axes)
41+
end
42+
43+
# Fix ambiguity error with `BlockSparseArrays.jl`.
44+
function Base.similar(
45+
a::AnyAbstractBlockSparseArray,
46+
elt::Type,
47+
axes::Tuple{
48+
AbstractGradedUnitRange,AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}
49+
},
50+
)
51+
return similar_blocksparse(a, elt, axes)
52+
end
53+
54+
function getindex_blocksparse(a::AbstractArray, I::AbstractUnitRange...)
55+
a′ = similar(a, only.(axes.(I))...)
56+
a′ .= a
57+
return a′
58+
end
59+
60+
function Base.getindex(
61+
a::AbstractArray, I1::AbstractGradedUnitRange, I_rest::AbstractGradedUnitRange...
62+
)
63+
return getindex_blocksparse(a, I1, I_rest...)
64+
end
65+
66+
# Fix ambiguity errors.
67+
function Base.getindex(
68+
a::AnyAbstractBlockSparseArray,
69+
I1::AbstractGradedUnitRange,
70+
I_rest::AbstractGradedUnitRange...,
71+
)
72+
return getindex_blocksparse(a, I1, I_rest...)
73+
end
74+
function Base.getindex(
75+
a::AnyAbstractBlockSparseArray{<:Any,2},
76+
I1::AbstractGradedUnitRange,
77+
I2::AbstractGradedUnitRange,
78+
)
79+
return getindex_blocksparse(a, I1, I2)
80+
end
81+
82+
end

src/abstractblocksparsearray/broadcast.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,23 @@ function Broadcast.BroadcastStyle(
4646
)
4747
return BlockSparseArrayStyle{ndims(arraytype)}()
4848
end
49+
50+
# These catch cases that aren't caught by the standard
51+
# `BlockSparseArrayStyle` definition, and also fix
52+
# ambiguity issues.
53+
function Base.copyto!(dest::AnyAbstractBlockSparseArray, bc::Broadcasted)
54+
copyto_blocksparse!(dest, bc)
55+
return dest
56+
end
57+
function Base.copyto!(
58+
dest::AnyAbstractBlockSparseArray, bc::Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}
59+
)
60+
copyto_blocksparse!(dest, bc)
61+
return dest
62+
end
63+
function Base.copyto!(
64+
dest::AnyAbstractBlockSparseArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}}
65+
) where {N}
66+
copyto_blocksparse!(dest, bc)
67+
return dest
68+
end

src/abstractblocksparsearray/cat.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
using DerivableInterfaces: @interface, interface
22

3-
# TODO: Define with `@derive`.
4-
function Base.cat(as::AnyAbstractBlockSparseArray...; dims)
3+
function Base._cat(dims, as::AnyAbstractBlockSparseArray...)
4+
# TODO: Call `DerivableInterfaces.cat_along(dims, as...)` instead,
5+
# for better inferability. See:
6+
# https://github.com/ITensor/DerivableInterfaces.jl/pull/13
7+
# https://github.com/ITensor/DerivableInterfaces.jl/pull/17
58
return @interface interface(as...) cat(as...; dims)
69
end

src/abstractblocksparsearray/map.jl

Lines changed: 11 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,8 @@
11
using ArrayLayouts: LayoutArray
2-
using BlockArrays: blockisequal
3-
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
4-
using GPUArraysCore: @allowscalar
2+
using BlockArrays: AbstractBlockVector, Block
53
using LinearAlgebra: Adjoint, Transpose
6-
using SparseArraysBase: SparseArraysBase, SparseArrayStyle
7-
8-
# Returns `Vector{<:CartesianIndices}`
9-
function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray})
10-
combined_axes = combine_axes(axes.(as)...)
11-
stored_blocked_cartesianindices_as = map(as) do a
12-
return blocked_cartesianindices(axes(a), combined_axes, eachblockstoredindex(a))
13-
end
14-
return (stored_blocked_cartesianindices_as...)
15-
end
16-
17-
# This is used by `map` to get the output axes.
18-
# This is type piracy, try to avoid this, maybe requires defining `map`.
19-
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)
20-
21-
reblock(a) = a
224

5+
# TODO: Make this more general, independent of `AbstractBlockSparseArray`.
236
# If the blocking of the slice doesn't match the blocking of the
247
# parent array, reblock according to the blocking of the parent array.
258
function reblock(
@@ -32,12 +15,14 @@ function reblock(
3215
return @view parent(a)[UnitRange{Int}.(parentindices(a))...]
3316
end
3417

18+
# TODO: Make this more general, independent of `AbstractBlockSparseArray`.
3519
function reblock(
3620
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}}
3721
)
3822
return @view parent(a)[map(I -> I.array, parentindices(a))...]
3923
end
4024

25+
# TODO: Make this more general, independent of `AbstractBlockSparseArray`.
4126
function reblock(
4227
a::SubArray{
4328
<:Any,
@@ -50,77 +35,18 @@ function reblock(
5035
return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...]
5136
end
5237

53-
# `map!` specialized to zero-dimensional inputs.
54-
function map_zero_dim! end
55-
56-
@interface ::AbstractArrayInterface function map_zero_dim!(
57-
f, a_dest::AbstractArray, a_srcs::AbstractArray...
58-
)
59-
@allowscalar a_dest[] = f.(map(a_src -> a_src[], a_srcs)...)
38+
function Base.map!(f, a_dest::AbstractArray, a_srcs::AnyAbstractBlockSparseArray...)
39+
@interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...)
6040
return a_dest
6141
end
62-
63-
# TODO: Move to `blocksparsearrayinterface/map.jl`.
64-
# TODO: Rewrite this so that it takes the blocking structure
65-
# made by combining the blocking of the axes (i.e. the blocking that
66-
# is used to determine `union_stored_blocked_cartesianindices(...)`).
67-
# `reblock` is a partial solution to that, but a bit ad-hoc.
68-
## TODO: Make this an `@interface AbstractBlockSparseArrayInterface` function.
69-
@interface interface::AbstractBlockSparseArrayInterface function Base.map!(
70-
f, a_dest::AbstractArray, a_srcs::AbstractArray...
71-
)
72-
if iszero(ndims(a_dest))
73-
@interface interface map_zero_dim!(f, a_dest, a_srcs...)
74-
return a_dest
75-
end
76-
77-
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
78-
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
79-
BI_dest = blockindexrange(a_dest, I)
80-
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
81-
# TODO: Investigate why this doesn't work:
82-
# block_dest = @view a_dest[_block(BI_dest)]
83-
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...]
84-
# TODO: Investigate why this doesn't work:
85-
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
86-
block_srcs = ntuple(length(a_srcs)) do i
87-
return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
88-
end
89-
subblock_dest = @view block_dest[BI_dest.indices...]
90-
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
91-
# TODO: Use `map!!` to handle immutable blocks.
92-
map!(f, subblock_dest, subblock_srcs...)
93-
# Replace the entire block, handles initializing new blocks
94-
# or if blocks are immutable.
95-
blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] = block_dest
96-
end
42+
function Base.map!(f, a_dest::AnyAbstractBlockSparseArray, a_srcs::AbstractArray...)
43+
@interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...)
9744
return a_dest
9845
end
99-
100-
# TODO: Move to `blocksparsearrayinterface/map.jl`.
101-
@interface ::AbstractBlockSparseArrayInterface function Base.mapreduce(
102-
f, op, as::AbstractArray...; kwargs...
46+
function Base.map!(
47+
f, a_dest::AnyAbstractBlockSparseArray, a_srcs::AnyAbstractBlockSparseArray...
10348
)
104-
# TODO: Define an `init` value based on the element type.
105-
return @interface interface(blocks.(as)...) mapreduce(
106-
block -> mapreduce(f, op, block), op, blocks.(as)...; kwargs...
107-
)
108-
end
109-
110-
# TODO: Move to `blocksparsearrayinterface/map.jl`.
111-
@interface ::AbstractBlockSparseArrayInterface function Base.iszero(a::AbstractArray)
112-
# TODO: Just call `iszero(blocks(a))`?
113-
return @interface interface(blocks(a)) iszero(blocks(a))
114-
end
115-
116-
# TODO: Move to `blocksparsearrayinterface/map.jl`.
117-
@interface ::AbstractBlockSparseArrayInterface function Base.isreal(a::AbstractArray)
118-
# TODO: Just call `isreal(blocks(a))`?
119-
return @interface interface(blocks(a)) isreal(blocks(a))
120-
end
121-
122-
function Base.map!(f, a_dest::AbstractArray, a_srcs::AnyAbstractBlockSparseArray...)
123-
@interface interface(a_srcs...) map!(f, a_dest, a_srcs...)
49+
@interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...)
12450
return a_dest
12551
end
12652

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using BlockArrays:
1111
BlockedVector,
1212
block,
1313
blockcheckbounds,
14+
blockisequal,
1415
blocklengths,
1516
blocks,
1617
findblockindex
@@ -40,6 +41,31 @@ function eachstoredblock(a::AbstractArray)
4041
return storedvalues(blocks(a))
4142
end
4243

44+
# TODO: Generalize this, this catches simple cases
45+
# where the more general definition isn't specific enough.
46+
blocktype(a::Array) = typeof(a)
47+
# TODO: Maybe unwrap SubArrays?
48+
function blocktype(a::AbstractArray)
49+
# TODO: Unfortunately, this doesn't always give
50+
# a concrete type, even when it could be concrete, i.e.
51+
#=
52+
```julia
53+
julia> eltype(blocks(BlockArray(randn(2, 2), [1, 1], [1, 1])))
54+
Matrix{Float64} (alias for Array{Float64, 2})
55+
56+
julia> eltype(blocks(BlockedArray(randn(2, 2), [1, 1], [1, 1])))
57+
AbstractMatrix{Float64} (alias for AbstractArray{Float64, 2})
58+
59+
julia> eltype(blocks(randn(2, 2)))
60+
AbstractMatrix{Float64} (alias for AbstractArray{Float64, 2})
61+
```
62+
=#
63+
if isempty(blocks(a))
64+
return eltype(blocks(a))
65+
end
66+
return eltype(first(blocks(a)))
67+
end
68+
4369
abstract type AbstractBlockSparseArrayInterface <: AbstractSparseArrayInterface end
4470

4571
# TODO: Also support specifying the `blocktype` along with the `eltype`.

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted
2+
using GPUArraysCore: @allowscalar
23
using MapBroadcast: Mapped
34
using DerivableInterfaces: DerivableInterfaces, @interface
45

@@ -36,14 +37,28 @@ function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type)
3637
return similar(first(m.args), elt, combine_axes(axes.(m.args)...))
3738
end
3839

40+
# Catches cases like `dest .= value` or `dest .= value1 .+ value2`.
41+
# If the RHS is zero, this makes sure that the storage is emptied,
42+
# which is logic that is handled by `fill!`.
43+
function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}})
44+
# `[]` is used to unwrap zero-dimensional arrays.
45+
value = @allowscalar bc.f(bc.args...)[]
46+
return @interface BlockSparseArrayInterface() fill!(dest, value)
47+
end
48+
3949
# Broadcasting implementation
4050
# TODO: Delete this in favor of `DerivableInterfaces` version.
41-
function Base.copyto!(
42-
dest::AbstractArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}}
43-
) where {N}
51+
function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted)
4452
# convert to map
4553
# flatten and only keep the AbstractArray arguments
4654
m = Mapped(bc)
47-
@interface interface(bc) map!(m.f, dest, m.args...)
55+
@interface interface(dest, bc) map!(m.f, dest, m.args...)
56+
return dest
57+
end
58+
59+
function Base.copyto!(
60+
dest::AbstractArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}}
61+
) where {N}
62+
copyto_blocksparse!(dest, bc)
4863
return dest
4964
end

0 commit comments

Comments
 (0)