Skip to content

Commit e033242

Browse files
authored
Remove special-casing of Ref in broadcast. (#510)
1 parent 4278412 commit e033242

File tree

2 files changed

+3
-27
lines changed

2 files changed

+3
-27
lines changed

src/host/broadcast.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,6 @@ using Base.Broadcast
44

55
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
66

7-
const BroadcastGPUArray{T} = Union{AnyGPUArray{T},
8-
Base.RefValue{<:AbstractGPUArray{T}}}
9-
10-
# Ref is special: it's not a real wrapper, so not part of Adapt,
11-
# but it is commonly used to bypass broadcasting of an argument
12-
# so we need to preserve its dimensionless properties.
13-
BroadcastStyle(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray} =
14-
typeof(BroadcastStyle(AT))(Val(0))
15-
backend(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray} = backend(AT)
167
# but make sure we don't dispatch to the optimized copy method that directly indexes
178
function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
189
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
@@ -41,7 +32,7 @@ end
4132
return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
4233
end
4334

44-
@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict
35+
@inline Base.copyto!(dest::AnyGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict
4536

4637
@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = _copyto!(dest, bc)
4738

@@ -77,7 +68,7 @@ end
7768
allequal(x) = true
7869
allequal(x, y, z...) = x == y && allequal(y, z...)
7970

80-
function Base.map(f, x::BroadcastGPUArray, xs::AbstractArray...)
71+
function Base.map(f, x::AnyGPUArray, xs::AbstractArray...)
8172
# if argument sizes match, their shape needs to be preserved
8273
xs = (x, xs...)
8374
if allequal(size.(xs)...)
@@ -96,7 +87,7 @@ function Base.map(f, x::BroadcastGPUArray, xs::AbstractArray...)
9687
return map!(f, dest, xs...)
9788
end
9889

99-
function Base.map!(f, dest::BroadcastGPUArray, xs::AbstractArray...)
90+
function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
10091
# custom broadcast, ignoring the container size mismatches
10192
# (avoids the reshape + view that our mapreduce impl has to do)
10293
indices = LinearIndices.((dest, xs...))

test/testsuite/broadcasting.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -156,21 +156,6 @@ function broadcasting(AT, eltypes)
156156
map(+, x, y)
157157
end
158158
end
159-
160-
@testset "Ref" begin
161-
# as first arg, 0d broadcast
162-
@test compare(x->getindex.(Ref(x), 1), AT, ET[0])
163-
164-
void_setindex!(args...) = (setindex!(args...); return)
165-
@test compare(x->(void_setindex!.(Ref(x), ET(1)); x), AT, ET[0])
166-
167-
# regular broadcast
168-
a = AT(rand(ET, 10))
169-
b = AT(rand(ET, 10))
170-
cpy(i,a,b) = (a[i] = b[i]; return)
171-
cpy.(1:10, Ref(a), Ref(b))
172-
@test Array(a) == Array(b)
173-
end
174159
end
175160

176161
@testset "stackoverflow in copy(::Broadcast)" begin

0 commit comments

Comments
 (0)