Skip to content

Commit ad48a28

Browse files
committed
Construct BlockSparseArray when slicing with graded unit ranges
1 parent 37fecf7 commit ad48a28

File tree

4 files changed

+112
-7
lines changed

4 files changed

+112
-7
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
2525
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
2626

2727
[extensions]
28+
BlockSparseArraysGradedUnitRangesExt = "GradedUnitRanges"
2829
BlockSparseArraysTensorAlgebraExt = ["LabelledNumbers", "TensorAlgebra"]
2930

3031
[compat]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
module BlockSparseArraysGradedUnitRangesExt
2+
3+
using BlockSparseArrays: BlockSparseArray
4+
using GradedUnitRanges: AbstractGradedUnitRange
5+
6+
# A block spare array similar to the input (dense) array.
7+
# TODO: Make `BlockSparseArrays.blocksparse_similar` more general and use that,
8+
# and also turn it into an DerivableInterfaces.jl-based interface function.
9+
function similar_blocksparse(
10+
a::AbstractArray,
11+
elt::Type,
12+
axes::Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}},
13+
)
14+
# TODO: Probably need to unwrap the type of `a` in certain cases
15+
# to make a proper block type.
16+
return BlockSparseArray{elt,length(axes),typeof(a)}(axes)
17+
end
18+
19+
function Base.similar(
20+
a::AbstractArray,
21+
elt::Type,
22+
axes::Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}},
23+
)
24+
return similar_blocksparse(a, elt, axes)
25+
end
26+
27+
# Fix ambiguity error with `BlockArrays.jl`.
28+
function Base.similar(
29+
a::StridedArray,
30+
elt::Type,
31+
axes::Tuple{
32+
AbstractGradedUnitRange,AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}
33+
},
34+
)
35+
return similar_blocksparse(a, elt, axes)
36+
end
37+
38+
function Base.getindex(a::AbstractArray, I::AbstractGradedUnitRange...)
39+
a′ = similar(a, only.(axes.(I))...)
40+
a′ .= a
41+
return a′
42+
end
43+
44+
end

src/abstractblocksparsearray/map.jl

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,19 @@ end
6969
@interface interface::AbstractBlockSparseArrayInterface function Base.map!(
7070
f, a_dest::AbstractArray, a_srcs::AbstractArray...
7171
)
72+
if isempty(a_srcs)
73+
# Broadcast expressions of the form `a .= 2`.
74+
error("Not implemented.")
75+
end
7276
if iszero(ndims(a_dest))
7377
@interface interface map_zero_dim!(f, a_dest, a_srcs...)
7478
return a_dest
7579
end
7680

81+
# TODO: This assumes element types are numbers, generalize this logic.
82+
elt = promote_type(eltype.(a_srcs)...)
83+
f_preserves_zeros = f(zero(elt)) == zero(elt)
84+
7785
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
7886
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
7987
BI_dest = blockindexrange(a_dest, I)
@@ -88,11 +96,13 @@ end
8896
end
8997
subblock_dest = @view block_dest[BI_dest.indices...]
9098
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
91-
# TODO: Use `map!!` to handle immutable blocks.
92-
map!(f, subblock_dest, subblock_srcs...)
93-
# Replace the entire block, handles initializing new blocks
94-
# or if blocks are immutable.
95-
blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] = block_dest
99+
if f_preserves_zeros && any(!iszero, subblock_srcs)
100+
# TODO: Use `map!!` to handle immutable blocks.
101+
map!(f, subblock_dest, subblock_srcs...)
102+
# Replace the entire block, handles initializing new blocks
103+
# or if blocks are immutable.
104+
blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] = block_dest
105+
end
96106
end
97107
return a_dest
98108
end
@@ -120,7 +130,17 @@ end
120130
end
121131

122132
function Base.map!(f, a_dest::AbstractArray, a_srcs::AnyAbstractBlockSparseArray...)
123-
@interface interface(a_srcs...) map!(f, a_dest, a_srcs...)
133+
@interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...)
134+
return a_dest
135+
end
136+
function Base.map!(f, a_dest::AnyAbstractBlockSparseArray, a_srcs::AbstractArray...)
137+
@interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...)
138+
return a_dest
139+
end
140+
function Base.map!(
141+
f, a_dest::AnyAbstractBlockSparseArray, a_srcs::AnyAbstractBlockSparseArray...
142+
)
143+
@interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...)
124144
return a_dest
125145
end
126146

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,46 @@ function Base.copyto!(
4444
# convert to map
4545
# flatten and only keep the AbstractArray arguments
4646
m = Mapped(bc)
47-
@interface interface(bc) map!(m.f, dest, m.args...)
47+
@interface interface(dest, bc) map!(m.f, dest, m.args...)
48+
return dest
49+
end
50+
51+
# Broadcasting implementation
52+
# TODO: Delete this in favor of `DerivableInterfaces` version.
53+
function Base.copyto!(dest::AnyAbstractBlockSparseArray, bc::Broadcasted)
54+
# convert to map
55+
# flatten and only keep the AbstractArray arguments
56+
m = Mapped(bc)
57+
# TODO: Include `bc` when determining interface, currently
58+
# `interface(::Type{<:Base.Broadcast.DefaultArrayStyle})`
59+
# isn't defined.
60+
@interface interface(dest) map!(m.f, dest, m.args...)
61+
return dest
62+
end
63+
64+
# Broadcasting implementation
65+
# TODO: Delete this in favor of `DerivableInterfaces` version.
66+
function Base.copyto!(
67+
dest::AnyAbstractBlockSparseArray, bc::Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}
68+
)
69+
# convert to map
70+
# flatten and only keep the AbstractArray arguments
71+
m = Mapped(bc)
72+
# TODO: Include `bc` when determining interface, currently
73+
# `interface(::Type{<:Base.Broadcast.DefaultArrayStyle})`
74+
# isn't defined.
75+
@interface interface(dest) map!(m.f, dest, m.args...)
76+
return dest
77+
end
78+
79+
# Broadcasting implementation
80+
# TODO: Delete this in favor of `DerivableInterfaces` version.
81+
function Base.copyto!(
82+
dest::AnyAbstractBlockSparseArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}}
83+
) where {N}
84+
# convert to map
85+
# flatten and only keep the AbstractArray arguments
86+
m = Mapped(bc)
87+
@interface interface(dest, bc) map!(m.f, dest, m.args...)
4888
return dest
4989
end

0 commit comments

Comments
 (0)