Skip to content

Commit c2b5da2

Browse files
Merge pull request #136 from jdeldre/je/broadcast-similar
Preserve custom AbstractArray types during broadcasting of ArrayPartition
2 parents 734999e + 056acaf commit c2b5da2

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

src/array_partition.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,8 @@ _npartitions(args::Tuple{Any}) = npartitions(args[1])
304304
_npartitions(args::Tuple{}) = 0
305305

306306
# drop axes because it is easier to recompute
307-
@inline unpack(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
308-
@inline unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
307+
@inline unpack(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted(bc.f, unpack_args(i, bc.args))
308+
@inline unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, i) where Style = Broadcast.Broadcasted(bc.f, unpack_args(i, bc.args))
309309
unpack(x,::Any) = x
310310
unpack(x::ArrayPartition, i) = x.x[i]
311311

test/partitions_test.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,43 @@ end
136136
foo(xcde0, xce0)
137137
#@test 0 == @allocated foo(xcde0, xce0)
138138

139+
# Custom AbstractArray types broadcasting
140+
struct MyType{T} <: AbstractVector{T}
141+
data :: Vector{T}
142+
end
143+
Base.similar(A::MyType{T}) where {T} = MyType{T}(similar(A.data))
144+
Base.similar(A::MyType{T},::Type{S}) where {T,S} = MyType(similar(A.data,S))
145+
146+
Base.size(A::MyType) = size(A.data)
147+
Base.getindex(A::MyType, i::Int) = getindex(A.data,i)
148+
Base.setindex!(A::MyType, v, i::Int) = setindex!(A.data,v,i)
149+
Base.IndexStyle(::MyType) = IndexLinear()
150+
151+
Base.BroadcastStyle(::Type{<:MyType}) = Broadcast.ArrayStyle{MyType}()
152+
153+
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyType}},::Type{T}) where {T}
154+
similar(find_mt(bc),T)
155+
end
156+
157+
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyType}})
158+
similar(find_mt(bc))
159+
end
160+
161+
find_mt(bc::Base.Broadcast.Broadcasted) = find_mt(bc.args)
162+
find_mt(args::Tuple) = find_mt(find_mt(args[1]), Base.tail(args))
163+
find_mt(x) = x
164+
find_mt(::Tuple{}) = nothing
165+
find_mt(a::MyType, rest) = a
166+
find_mt(::Any, rest) = find_mt(rest)
167+
168+
ap = ArrayPartition(MyType(ones(10)),collect(1:2))
169+
up = ap .+ 1
170+
@test typeof(ap) == typeof(up)
171+
172+
up = 2 .* ap .+ 1
173+
@test typeof(ap) == typeof(up)
174+
175+
139176
@testset "ArrayInterface.ismutable(ArrayPartition($a, $b)) == $r" for (a, b, r) in ((1,2, false), ([1], 2, false), ([1], [2], true))
140177
@test ArrayInterface.ismutable(ArrayPartition(a, b)) == r
141178
end

0 commit comments

Comments
 (0)