Skip to content

Commit 7807069

Browse files
committed
Handle broadcasting Ref.
1 parent b16a84b commit 7807069

File tree

3 files changed

+41
-2
lines changed

3 files changed

+41
-2
lines changed

src/host/broadcast.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,22 @@ end
2626
# This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
2727
# and we could define our methods in terms of Union{AbstractGPUArray, WrappedArray{<:Any, <:AbstractGPUArray}}
2828
@eval const GPUDestArray =
29-
Union{AbstractGPUArray, $((:($W where {AT <: AbstractGPUArray}) for (W, _) in Adapt.wrappers)...)}
29+
Union{AbstractGPUArray,
30+
$((:($W where {AT <: AbstractGPUArray}) for (W, _) in Adapt.wrappers)...),
31+
Base.RefValue{<:AbstractGPUArray} }
32+
33+
# Ref is special: it's not a real wrapper, so not part of Adapt,
34+
# but it is commonly used to bypass broadcasting of an argument
35+
# so we need to preserve its dimensionless properties.
36+
BroadcastStyle(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray} = typeof(BroadcastStyle(AT))(Val(0))
37+
backend(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray} = backend(AT)
38+
# but make sure we don't dispatch to the optimized copy method that directly indexes
39+
function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
40+
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
41+
isbitstype(ElType) || error("Cannot broadcast function returning non-isbits $ElType.")
42+
dest = copyto!(similar(bc, ElType), bc)
43+
return @allowscalar dest[CartesianIndex()] # 0D broadcast needs to unwrap results
44+
end
3045

3146
# We purposefully only specialize `copyto!`, dependent packages need to make sure that they
3247
# can handle:
@@ -47,7 +62,15 @@ end
4762
bc′ = Broadcast.preprocess(dest, bc)
4863
gpu_call(dest, bc′) do ctx, dest, bc′
4964
let I = CartesianIndex(@cartesianidx(dest))
50-
@inbounds dest[I] = bc′[I]
65+
#@inbounds dest[I] = bc′[I]
66+
@inbounds let
67+
val = bc′[I]
68+
if val !== nothing
69+
# FIXME: CuArrays.jl crashes on assigning Nothing (this happens with
70+
# broadcasts that don't return anything but assign anyway)
71+
dest[I] = val
72+
end
73+
end
5174
end
5275
return
5376
end

src/reference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ struct JLArray{T, N} <: AbstractGPUArray{T, N}
147147
dims::Dims{N}
148148

149149
function JLArray{T,N}(data::Array{T, N}, dims::Dims{N}) where {T,N}
150+
@assert isbitstype(T) "JLArray only supports bits types"
150151
new(data, dims)
151152
end
152153
end

test/testsuite/broadcasting.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,21 @@ function broadcasting(AT)
126126
x /= 2
127127
@test collect(x)[] == 0.5
128128
end
129+
130+
@testset "Ref" begin
131+
# as first arg, 0d broadcast
132+
@test compare(x->getindex.(Ref(x),1), AT, [0])
133+
134+
void_setindex!(args...) = (setindex!(args...); return)
135+
@test compare(x->(void_setindex!.(Ref(x),1); x), AT, [0])
136+
137+
# regular broadcast
138+
a = AT(rand(10))
139+
b = AT(rand(10))
140+
cpy(i,a,b) = (a[i] = b[i]; return)
141+
cpy.(1:10, Ref(a), Ref(b))
142+
@test Array(a) == Array(b)
143+
end
129144
end
130145

131146
function vec3(AT)

0 commit comments

Comments
 (0)