Skip to content

Commit 449eed4

Browse files
committed
Use Adapt.jl to properly convert arguments to the device.
1 parent a0fbd02 commit 449eed4

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

src/reference.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
module JLArrays
77

8+
export JLArray
9+
810
using GPUArrays
911

10-
export JLArray
12+
using Adapt
1113

1214

1315
#
@@ -52,12 +54,19 @@ function JLKernelContext(ctx::JLKernelContext, threadidx::Int)
5254
)
5355
end
5456

55-
to_device(ctx, x::Tuple) = to_device.(Ref(ctx), x)
56-
to_device(ctx, x) = x
57+
# Field-less marker type: an instance is passed as the first argument to `adapt`
# so that the `Adapt.adapt_structure` / `Adapt.adapt_storage` methods defined in
# this file (which dispatch on `::Adaptor`) are selected.
struct Adaptor end
58+
# Convert a single argument by calling `adapt` with a fresh `Adaptor()`.
# `gpu_call` broadcasts this over its `args` to build `device_args`.
jlconvert(arg) = adapt(Adaptor(), arg)
59+
60+
# FIXME: add Ref to Adapt.jl (but make sure it doesn't cause ambiguities with CUDAnative's)
61+
# Immutable `Ref{T}` subtype wrapping a single value `x`. It is what a
# `Base.RefValue` argument is replaced with when adapted with an `Adaptor`.
struct JlRefValue{T} <: Ref{T}
62+
x::T
63+
end
64+
# Dereference: `r[]` on a `JlRefValue` returns the wrapped value `r.x`.
Base.getindex(r::JlRefValue) = r.x
65+
# Adapt a `Base.RefValue`: read its contents with `r[]`, adapt them with the same
# adaptor `to`, and rewrap the result in a `JlRefValue`.
Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))
5766

5867
function GPUArrays.gpu_call(::JLBackend, f, args...; blocks::Int, threads::Int)
5968
ctx = JLKernelContext(threads, blocks)
60-
device_args = to_device.(Ref(ctx), args)
69+
device_args = jlconvert.(args)
6170
tasks = Array{Task}(undef, threads)
6271
@allowscalar for blockidx in 1:blocks
6372
ctx.blockidx = blockidx
@@ -267,8 +276,8 @@ GPUArrays.device(x::JLArray) = JLDevice()
267276

268277
GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
269278

270-
to_device(ctx, x::JLArray{T,N}) where {T,N} = JLDeviceArray{T,N}(x.data, x.dims)
271-
to_device(ctx, x::Base.RefValue{<: JLArray}) = Base.RefValue(to_device(ctx, x[]))
279+
# Adapt a `JLArray{T,N}` into a `JLDeviceArray{T,N}` with the same element type
# and rank, constructed from the array's `data` and `dims` fields.
Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
280+
JLDeviceArray{T,N}(x.data, x.dims)
272281

273282
GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
274283
reshape(reinterpret(T, A.data), size)

0 commit comments

Comments
 (0)