Skip to content

Commit cc7bb30

Browse files
authored
Merge pull request #240 from JuliaGPU/tb/painful_changes
Fixes to broadcast and indexing
2 parents cf699ef + 7807069 commit cc7bb30

File tree

7 files changed

+101
-45
lines changed

7 files changed

+101
-45
lines changed

src/host/broadcast.jl

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,21 @@
11
# broadcasting operations
22

3+
export AbstractGPUArrayStyle
4+
35
using Base.Broadcast
46

5-
import Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
7+
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle
68

7-
# we define a generic `BroadcastStyle` here that should be sufficient for most cases.
8-
# dependent packages like `CuArrays` can define their own `BroadcastStyle` allowing
9-
# them to further change or optimize broadcasting.
10-
#
11-
# TODO: investigate if we should define out own `GPUArrayStyle{N} <: AbstractArrayStyle{N}`
12-
#
13-
# NOTE: this uses the specific `T` that was used e.g. `JLArray` or `CLArray` for ArrayStyle,
14-
# instead of using `ArrayStyle{AbstractGPUArray}`, due to the fact how `similar` works.
15-
BroadcastStyle(::Type{T}) where {T<:AbstractGPUArray} = ArrayStyle{T}()
9+
"""
10+
Abstract supertype for GPU array styles. The `N` parameter is the dimensionality.
11+
12+
Downstream implementations should provide a concrete array style type that inherits from
13+
this supertype.
14+
"""
15+
abstract type AbstractGPUArrayStyle{N} <: AbstractArrayStyle{N} end
1616

1717
# Wrapper types otherwise forget that they are GPU compatible
18-
#
19-
# NOTE: Don't directly use ArrayStyle{AbstractGPUArray} here since that would mean that `CuArrays`
20-
# customization no longer take effect.
18+
# NOTE: don't directly use GPUArrayStyle here, so as not to lose downstream customizations.
2119
for (W, ctor) in Adapt.wrappers
2220
@eval begin
2321
BroadcastStyle(::Type{<:$W}) where {AT<:AbstractGPUArray} = BroadcastStyle(AT)
@@ -28,7 +26,22 @@ end
2826
# This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
2927
# and we could define our methods in terms of Union{AbstractGPUArray, WrappedArray{<:Any, <:AbstractGPUArray}}
3028
@eval const GPUDestArray =
31-
Union{AbstractGPUArray, $((:($W where {AT <: AbstractGPUArray}) for (W, _) in Adapt.wrappers)...)}
29+
Union{AbstractGPUArray,
30+
$((:($W where {AT <: AbstractGPUArray}) for (W, _) in Adapt.wrappers)...),
31+
Base.RefValue{<:AbstractGPUArray} }
32+
33+
# Ref is special: it's not a real wrapper, so not part of Adapt,
34+
# but it is commonly used to bypass broadcasting of an argument
35+
# so we need to preserve its dimensionless properties.
36+
BroadcastStyle(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray} = typeof(BroadcastStyle(AT))(Val(0))
37+
backend(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray} = backend(AT)
38+
# but make sure we don't dispatch to the optimized copy method that directly indexes
39+
function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
40+
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
41+
isbitstype(ElType) || error("Cannot broadcast function returning non-isbits $ElType.")
42+
dest = copyto!(similar(bc, ElType), bc)
43+
return @allowscalar dest[CartesianIndex()] # 0D broadcast needs to unwrap results
44+
end
3245

3346
# We purposefully only specialize `copyto!`, dependent packages need to make sure that they
3447
# can handle:
@@ -49,7 +62,15 @@ end
4962
bc′ = Broadcast.preprocess(dest, bc)
5063
gpu_call(dest, bc′) do ctx, dest, bc′
5164
let I = CartesianIndex(@cartesianidx(dest))
52-
@inbounds dest[I] = bc′[I]
65+
#@inbounds dest[I] = bc′[I]
66+
@inbounds let
67+
val = bc′[I]
68+
if val !== nothing
69+
# FIXME: CuArrays.jl crashes on assigning Nothing (this happens with
70+
# broadcasts that don't return anything but assign anyway)
71+
dest[I] = val
72+
end
73+
end
5374
end
5475
return
5576
end

src/host/indexing.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# host-level indexing
22

3-
export allowscalar, @allowscalar, assertscalar
3+
export allowscalar, @allowscalar, @disallowscalar, assertscalar
44

55

66
# mechanism to disallow scalar operations
@@ -82,26 +82,18 @@ end
8282

8383
Base.IndexStyle(::Type{<:AbstractGPUArray}) = Base.IndexLinear()
8484

85-
function _getindex(xs::AbstractGPUArray{T}, i::Integer) where T
85+
function Base.getindex(xs::AbstractGPUArray{T}, i::Integer) where T
86+
ndims(xs) > 0 && assertscalar("scalar getindex")
8687
x = Array{T}(undef, 1)
8788
copyto!(x, 1, xs, i, 1)
8889
return x[1]
8990
end
9091

91-
function Base.getindex(xs::AbstractGPUArray{T}, i::Integer) where T
92-
ndims(xs) > 0 && assertscalar("scalar getindex")
93-
_getindex(xs, i)
94-
end
95-
96-
function _setindex!(xs::AbstractGPUArray{T}, v::T, i::Integer) where T
97-
x = T[v]
98-
copyto!(xs, i, x, 1, 1)
99-
return v
100-
end
101-
10292
function Base.setindex!(xs::AbstractGPUArray{T}, v::T, i::Integer) where T
10393
assertscalar("scalar setindex!")
104-
_setindex!(xs, v, i)
94+
x = T[v]
95+
copyto!(xs, i, x, 1, 1)
96+
return xs
10597
end
10698

10799
Base.setindex!(xs::AbstractGPUArray, v, i::Integer) = xs[i] = convert(eltype(xs), v)

src/host/mapreduce.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ gpu_promote_type(::typeof(min), ::Type{T}) where {T<: WidenReduceResult} = T
5757
gpu_promote_type(::typeof(abs), ::Type{Complex{T}}) where {T} = T
5858
gpu_promote_type(::typeof(abs2), ::Type{Complex{T}}) where {T} = T
5959

60-
import Base.Broadcast: Broadcasted, ArrayStyle
61-
const GPUSrcArray = Union{Broadcasted{ArrayStyle{AT}}, AbstractGPUArray{T, N}} where {T, N, AT<:AbstractGPUArray}
60+
import Base.Broadcast: Broadcasted
61+
const GPUSrcArray = Union{Broadcasted{<:AbstractGPUArrayStyle}, <:AbstractGPUArray}
6262

6363
function Base.mapreduce(f::Function, op::Function, A::GPUSrcArray; dims = :, init...)
6464
mapreduce_impl(f, op, init.data, A, dims)

src/host/quirks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@ if VERSION >= v"1.3.0-alpha.107"
2222
@inline combine_axes(A, B...) = broadcast_shape(axes(A), combine_axes(B...))
2323
combine_axes(A) = axes(A)
2424

25-
Broadcast._axes(::Broadcasted{ArrayStyle{AT}}, axes::Tuple) where {AT <: AbstractGPUArray} = axes
26-
@inline Broadcast._axes(bc::Broadcasted{ArrayStyle{AT}}, ::Nothing) where {AT <: AbstractGPUArray} = combine_axes(bc.args...)
25+
Broadcast._axes(::Broadcasted{<:AbstractGPUArrayStyle}, axes::Tuple) = axes
26+
@inline Broadcast._axes(bc::Broadcasted{<:AbstractGPUArrayStyle}, ::Nothing) = combine_axes(bc.args...)
2727
end

src/reference.jl

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
module JLArrays
77

8+
export JLArray
9+
810
using GPUArrays
911

10-
export JLArray
12+
using Adapt
1113

1214

1315
#
@@ -52,14 +54,21 @@ function JLKernelContext(ctx::JLKernelContext, threadidx::Int)
5254
)
5355
end
5456

55-
to_device(ctx, x::Tuple) = to_device.(Ref(ctx), x)
56-
to_device(ctx, x) = x
57+
struct Adaptor end
58+
jlconvert(arg) = adapt(Adaptor(), arg)
59+
60+
# FIXME: add Ref to Adapt.jl (but make sure it doesn't cause ambiguities with CUDAnative's)
61+
struct JlRefValue{T} <: Ref{T}
62+
x::T
63+
end
64+
Base.getindex(r::JlRefValue) = r.x
65+
Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))
5766

5867
function GPUArrays.gpu_call(::JLBackend, f, args...; blocks::Int, threads::Int)
5968
ctx = JLKernelContext(threads, blocks)
60-
device_args = to_device.(Ref(ctx), args)
69+
device_args = jlconvert.(args)
6170
tasks = Array{Task}(undef, threads)
62-
@allowscalar for blockidx in 1:blocks
71+
@disallowscalar for blockidx in 1:blocks
6372
ctx.blockidx = blockidx
6473
for threadidx in 1:threads
6574
thread_ctx = JLKernelContext(ctx, threadidx)
@@ -138,6 +147,7 @@ struct JLArray{T, N} <: AbstractGPUArray{T, N}
138147
dims::Dims{N}
139148

140149
function JLArray{T,N}(data::Array{T, N}, dims::Dims{N}) where {T,N}
150+
@assert isbitstype(T) "JLArray only supports bits types"
141151
new(data, dims)
142152
end
143153
end
@@ -193,15 +203,19 @@ Base.convert(::Type{T}, x::T) where T <: JLArray = x
193203

194204
## broadcast
195205

196-
using Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
206+
using Base.Broadcast: BroadcastStyle, Broadcasted
207+
208+
struct JLArrayStyle{N} <: AbstractGPUArrayStyle{N} end
209+
JLArrayStyle(::Val{N}) where N = JLArrayStyle{N}()
210+
JLArrayStyle{M}(::Val{N}) where {N,M} = JLArrayStyle{N}()
197211

198-
BroadcastStyle(::Type{<:JLArray}) = ArrayStyle{JLArray}()
212+
BroadcastStyle(::Type{JLArray{T,N}}) where {T,N} = JLArrayStyle{N}()
199213

200-
function Base.similar(bc::Broadcasted{ArrayStyle{JLArray}}, ::Type{T}) where T
214+
Base.similar(bc::Broadcasted{JLArrayStyle{N}}, ::Type{T}) where {N,T} =
201215
similar(JLArray{T}, axes(bc))
202-
end
203216

204-
Base.similar(bc::Broadcasted{ArrayStyle{JLArray}}, ::Type{T}, dims...) where {T} = JLArray{T}(undef, dims...)
217+
Base.similar(bc::Broadcasted{JLArrayStyle{N}}, ::Type{T}, dims...) where {N,T} =
218+
JLArray{T}(undef, dims...)
205219

206220

207221
## memory operations
@@ -263,8 +277,8 @@ GPUArrays.device(x::JLArray) = JLDevice()
263277

264278
GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
265279

266-
to_device(ctx, x::JLArray{T,N}) where {T,N} = JLDeviceArray{T,N}(x.data, x.dims)
267-
to_device(ctx, x::Base.RefValue{<: JLArray}) = Base.RefValue(to_device(ctx, x[]))
280+
Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
281+
JLDeviceArray{T,N}(x.data, x.dims)
268282

269283
GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
270284
reshape(reinterpret(T, A.data), size)

test/testsuite/broadcasting.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,29 @@ function broadcasting(AT)
118118
@test compare((A, B) -> A .* B .+ ET(10), AT, rand(ET, 40, 40), rand(ET, 40, 40))
119119
end
120120
end
121+
122+
@testset "0D" begin
123+
x = AT{Float64}(undef)
124+
x .= 1
125+
@test collect(x)[] == 1
126+
x /= 2
127+
@test collect(x)[] == 0.5
128+
end
129+
130+
@testset "Ref" begin
131+
# as first arg, 0d broadcast
132+
@test compare(x->getindex.(Ref(x),1), AT, [0])
133+
134+
void_setindex!(args...) = (setindex!(args...); return)
135+
@test compare(x->(void_setindex!.(Ref(x),1); x), AT, [0])
136+
137+
# regular broadcast
138+
a = AT(rand(10))
139+
b = AT(rand(10))
140+
cpy(i,a,b) = (a[i] = b[i]; return)
141+
cpy.(1:10, Ref(a), Ref(b))
142+
@test Array(a) == Array(b)
143+
end
121144
end
122145

123146
function vec3(AT)

test/testsuite/indexing.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,11 @@ function test_indexing(AT)
6262
@test Array(A) == Ac
6363
end
6464
end
65+
66+
@testset "get/setindex!" begin
67+
# literal calls to get/setindex! have different return types
68+
@test compare(x->getindex(x,1), AT, zeros(Int, 2))
69+
@test compare(x->setindex!(x,1,1), AT, zeros(Int, 2))
70+
end
6571
end
6672
end

0 commit comments

Comments
 (0)