|
26 | 26 | # This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
|
27 | 27 | # and we could define our methods in terms of Union{AbstractGPUArray, WrappedArray{<:Any, <:AbstractGPUArray}}
|
28 | 28 | @eval const GPUDestArray =
|
29 |
| - Union{AbstractGPUArray, $((:($W where {AT <: AbstractGPUArray}) for (W, _) in Adapt.wrappers)...)} |
| 29 | + Union{AbstractGPUArray, |
| 30 | + $((:($W where {AT <: AbstractGPUArray}) for (W, _) in Adapt.wrappers)...), |
| 31 | + Base.RefValue{<:AbstractGPUArray} } |
| 32 | + |
| 33 | +# Ref is special: it's not a real wrapper, so not part of Adapt, |
| 34 | +# but it is commonly used to bypass broadcasting of an argument |
| 35 | +# so we need to preserve its dimensionless properties. |
| 36 | +BroadcastStyle(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray} = typeof(BroadcastStyle(AT))(Val(0)) |
| 37 | +backend(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray} = backend(AT) |
| 38 | +# but make sure we don't dispatch to the optimized copy method that directly indexes |
| 39 | +function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}}) |
| 40 | + ElType = Broadcast.combine_eltypes(bc.f, bc.args) |
| 41 | + isbitstype(ElType) || error("Cannot broadcast function returning non-isbits $ElType.") |
| 42 | + dest = copyto!(similar(bc, ElType), bc) |
| 43 | + return @allowscalar dest[CartesianIndex()] # 0D broadcast needs to unwrap results |
| 44 | +end |
30 | 45 |
|
31 | 46 | # We purposefully only specialize `copyto!`, dependent packages need to make sure that they
|
32 | 47 | # can handle:
|
|
47 | 62 | bc′ = Broadcast.preprocess(dest, bc)
|
48 | 63 | gpu_call(dest, bc′) do ctx, dest, bc′
|
49 | 64 | let I = CartesianIndex(@cartesianidx(dest))
|
50 |
| - @inbounds dest[I] = bc′[I] |
| 65 | + #@inbounds dest[I] = bc′[I] |
| 66 | + @inbounds let |
| 67 | + val = bc′[I] |
| 68 | + if val !== nothing |
| 69 | + # FIXME: CuArrays.jl crashes on assigning Nothing (this happens with |
| 70 | + # broadcasts that don't return anything but assign anyway) |
| 71 | + dest[I] = val |
| 72 | + end |
| 73 | + end |
51 | 74 | end
|
52 | 75 | return
|
53 | 76 | end
|
|
0 commit comments