Skip to content

Commit c830dfd

Browse files
tkfdlfivefifty
authored andcommitted
Optimizing the block-style broadcasting (#66)
* Optimize block array broadcast * Define unsafe_view for BlockIndexRange * Flatten Broadcasted * Allow `ones(2)[Block(1)[1:1], Block(1)[1:1]]` * Disambiguate unsafe_view(::SubArray, ::BlockSlice...) * Use BlockSlice's indices in unsafe_view * Restrict an unsafe_view definition to BlockArray
1 parent cce57b6 commit c830dfd

File tree

4 files changed

+124
-11
lines changed

4 files changed

+124
-11
lines changed

src/BlockArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import Base: @propagate_inbounds, Array, to_indices, to_index,
2323
RangeIndex, Int, Integer, Number,
2424
+, -, min, max, *, isless, in, copy, copyto!, axes, @deprecate,
2525
BroadcastStyle
26-
using Base: dataids
26+
using Base: ReshapedArray, dataids
2727

2828

2929

src/blockbroadcast.jl

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,6 @@ BroadcastStyle(::BlockStyle{M}, ::PseudoBlockStyle{N}) where {M,N} = BlockStyle(
2525
BroadcastStyle(::PseudoBlockStyle{M}, ::BlockStyle{N}) where {M,N} = BlockStyle(Val(max(M,N)))
2626

2727

28-
####
29-
# Default to standard Array broadcast
30-
####
31-
32-
3328
# following code modified from julia/base/broadcast.jl
3429
broadcast_cumulsizes(::Number) = ()
3530
broadcast_cumulsizes(A::AbstractArray) = cumulsizes(blocksizes(A))
@@ -48,11 +43,84 @@ blocksizes(A::Broadcasted{<:AbstractArrayStyle{N}}) where N =
4843
BlockSizes(combine_cumulsizes(broadcast_cumulsizes.(A.args)...))
4944

5045

51-
copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractBlockStyle{N}}) where N =
52-
copyto!(dest, Broadcasted{DefaultArrayStyle{N}}(bc.f, bc.args, bc.axes))
53-
5446
similar(bc::Broadcasted{<:AbstractBlockStyle{N}}, ::Type{T}) where {T,N} =
5547
BlockArray{T,N}(undef, blocksizes(bc))
5648

5749
similar(bc::Broadcasted{PseudoBlockStyle{N}}, ::Type{T}) where {T,N} =
5850
PseudoBlockArray{T,N}(undef, blocksizes(bc))
51+
52+
53+
subblocks(::Any, bs::BlockSizes, dim::Integer) =
54+
(nothing for _ in 1:nblocks(bs, dim))
55+
56+
function subblocks(arr::AbstractArray, bs::BlockSizes, dim::Integer)
57+
if size(arr, dim) == 1
58+
return (BlockIndexRange(Block(1), 1:1) for _ in 1:nblocks(bs, dim))
59+
end
60+
j = 1
61+
next = 1
62+
arrstops = cumulsizes(arr, dim)
63+
return (
64+
let n = blocksize(bs, dim, i)
65+
start = next
66+
next = start + n
67+
j0 = j
68+
if arrstops[j + 1] == next
69+
j += 1
70+
end
71+
BlockIndexRange(Block(j0), (start:next - 1) .- (arrstops[j0] - 1))
72+
end
73+
for i in 1:nblocks(bs, dim))
74+
end
75+
76+
@inline _bview(arg, ::Vararg) = arg
77+
@inline _bview(A::AbstractArray, I...) = view(A, I...)
78+
79+
@generated function copyto!(
80+
dest::AbstractArray,
81+
bc::Broadcasted{<:AbstractBlockStyle{NDims}, <:Any, <:Any, Args},
82+
) where {NDims, Args <: Tuple}
83+
84+
NArgs = length(Args.parameters)
85+
86+
# `bvar(0, dim)` is a variable for BlockIndexRange of `dim`-th dimension
87+
# of `dest` array. `bvar(i, dim)` is a similar variable of `i`-th
88+
# argument in `bc.args`.
89+
bvar(i, dim) = Symbol("blockindexrange_", i, "_", dim)
90+
91+
function forloop(dim)
92+
if dim > 0
93+
quote
94+
for ($(bvar(0, dim)), $(bvar.(1:NArgs, dim)...),) in zip(
95+
subblocks(dest, bs, $dim),
96+
subblocks.(bc.args, Ref(bs), Ref($dim))...)
97+
$(forloop(dim - 1))
98+
end
99+
end
100+
else
101+
bview(a, i) = :(_bview($a, $([bvar(i, d) for d in 1:NDims]...)))
102+
destview = bview(:dest, 0)
103+
argblocks = [bview(:(bc.args[$i]), i) for i in 1:NArgs]
104+
quote
105+
broadcast!(bc.f, $destview, $(argblocks...))
106+
end
107+
end
108+
end
109+
110+
quote
111+
bs = blocksizes(bc)
112+
if blocksizes(dest) bs
113+
copyto!(PseudoBlockArray(dest, bs), bc)
114+
return dest
115+
end
116+
117+
$(forloop(NDims))
118+
return dest
119+
end
120+
end
121+
122+
@inline function Broadcast.instantiate(
123+
bc::Broadcasted{Style}) where {Style <:AbstractBlockStyle}
124+
bcf = Broadcast.flatten(Broadcasted{Nothing}(bc.f, bc.args, bc.axes))
125+
return Broadcasted{Style}(bcf.f, bcf.args, bcf.axes)
126+
end

src/blockindexrange.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,45 @@ reindex(V, idxs::Tuple{BlockSlice{<:BlockRange}, Vararg{Any}},
7979
idxs[1].indices[subidxs[1].indices]),
8080
reindex(V, tail(idxs), tail(subidxs))...))
8181

82+
# De-reference blocks before creating a view to avoid taking `global2blockindex`
83+
# path in `AbstractBlockStyle` broadcasting.
84+
@inline function Base.unsafe_view(
85+
A::BlockArray{<:Any, N},
86+
I::Vararg{BlockSlice{<:BlockIndexRange{1}}, N}) where {N}
87+
@_propagate_inbounds_meta
88+
B = A[map(x -> x.block.block, I)...]
89+
return view(B, _splatmap(x -> x.block.indices, I)...)
90+
end
91+
92+
@inline function Base.unsafe_view(
93+
A::PseudoBlockArray{<:Any, N},
94+
I::Vararg{BlockSlice{<:BlockIndexRange{1}}, N}) where {N}
95+
@_propagate_inbounds_meta
96+
return view(A.blocks, map(x -> x.indices, I)...)
97+
end
98+
99+
@inline function Base.unsafe_view(
100+
A::ReshapedArray{<:Any, N, <:AbstractBlockArray{<:Any, M}},
101+
I::Vararg{BlockSlice{<:BlockIndexRange{1}}, N}) where {N, M}
102+
@_propagate_inbounds_meta
103+
# Note: assuming that I[M+1:end] are verified to be singletons
104+
return reshape(view(A.parent, I[1:M]...), Val(N))
105+
end
106+
107+
@inline function Base.unsafe_view(
108+
A::AbstractArray{<:Any, N},
109+
I::Vararg{BlockSlice{<:BlockIndexRange{1}}, N}) where {N}
110+
@_propagate_inbounds_meta
111+
return view(A, map(x -> x.indices, I)...)
112+
end
82113

114+
# Disambiguation
115+
@inline function Base.unsafe_view(
116+
A::SubArray,
117+
I::Vararg{BlockSlice{<:BlockIndexRange{1}}, N}) where {N}
118+
@_propagate_inbounds_meta
119+
return view(A, map(x -> x.indices, I)...)
120+
end
83121

84122

85123
# #################

src/views.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,15 @@ blocksizes(V::SubArray) = BlockSizes(_sub_cumul_sizes(cumulsizes(parent(V)), par
7171
7272
Returns the indices associated with a block as a `BlockSlice`.
7373
"""
74-
unblock(A::AbstractArray{T,N}, inds, I) where {T, N} = _unblock(cumulsizes(A, N - length(inds) + 1), I)
75-
74+
function unblock(A::AbstractArray{T,N}, inds, I) where {T, N}
75+
if length(inds) == 0
76+
# Allow `ones(2)[Block(1)[1:1], Block(1)[1:1]]` which is
77+
# similar to `ones(2)[1:1, 1:1]`.
78+
_unblock(Base.OneTo(2), I)
79+
else
80+
_unblock(cumulsizes(A, N - length(inds) + 1), I)
81+
end
82+
end
7683

7784

7885
to_index(::Block) = throw(ArgumentError("Block must be converted by to_indices(...)"))

0 commit comments

Comments
 (0)