Skip to content

Commit 3b0102b

Browse files
committed
Process Fill broadcasting before others
1 parent 004b173 commit 3b0102b

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

src/fillbroadcast.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,23 +118,31 @@ _isfill(f::Number) = true
118118
_isfill(f::Ref) = true
119119
_isfill(::Any) = false
120120

121-
function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{N}}) where {N}
121+
_broadcast_maybecopy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) = copy(bc)
122+
_broadcast_maybecopy(x) = x
123+
124+
function _fallback_copy(bc)
125+
# treat the fill components
126+
bc2 = Base.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...)
127+
# fallback style
128+
S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{ndims(bc)}}
129+
copy(convert(S, bc2))
130+
end
131+
132+
function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle})
122133
if _iszeros(bc)
123134
return Zeros(typeof(_getindex_value(bc)), axes(bc))
124135
elseif _isones(bc)
125136
return Ones(typeof(_getindex_value(bc)), axes(bc))
126137
elseif _isfill(bc)
127138
return Fill(_getindex_value(bc), axes(bc))
128139
else
129-
# fallback style
130-
S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{N}}
131-
copy(convert(S, bc))
140+
_fallback_copy(bc)
132141
end
133142
end
134143
# make the zero-dimensional case consistent with Base
135144
function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}})
136-
S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}
137-
copy(convert(S, bc))
145+
_fallback_copy(bc)
138146
end
139147

140148
# some cases that preserve 0d

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,6 +1239,11 @@ end
12391239
F = Fill(1, 2)
12401240
@test g.(F, "a") === f.(F)
12411241
end
1242+
1243+
@testset "early binding" begin
1244+
A = ones(2) .+ (x -> rand()).(Fill(2,2))
1245+
@test all(==(A[1]), A)
1246+
end
12421247
end
12431248

12441249
@testset "map" begin

0 commit comments

Comments
 (0)