@@ -5,9 +5,11 @@
 
 module JLArrays
 
+export JLArray
+
 using GPUArrays
 
-export JLArray
+using Adapt
 
 
 #
@@ -52,12 +54,19 @@ function JLKernelContext(ctx::JLKernelContext, threadidx::Int)
     )
 end
 
-to_device(ctx, x::Tuple) = to_device.(Ref(ctx), x)
-to_device(ctx, x) = x
+struct Adaptor end
+jlconvert(arg) = adapt(Adaptor(), arg)
+
+# FIXME: add Ref to Adapt.jl (but make sure it doesn't cause ambiguities with CUDAnative's)
+struct JlRefValue{T} <: Ref{T}
+    x::T
+end
+Base.getindex(r::JlRefValue) = r.x
+Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))
 
 function GPUArrays.gpu_call(::JLBackend, f, args...; blocks::Int, threads::Int)
     ctx = JLKernelContext(threads, blocks)
-    device_args = to_device.(Ref(ctx), args)
+    device_args = jlconvert.(args)
     tasks = Array{Task}(undef, threads)
     @allowscalar for blockidx in 1:blocks
         ctx.blockidx = blockidx
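
For context, a self-contained sketch of the conversion path this hunk introduces: jlconvert hands each kernel argument to Adapt.jl, which walks containers and rewrites the leaves. The HostArray/DeviceArray/HostRefValue names below are hypothetical stand-ins for JLArray/JLDeviceArray/JlRefValue, so the snippet runs outside this module; the rules mirror the ones added in this commit.

using Adapt

struct HostArray{T,N}               # stand-in for JLArray
    data::Array{T,N}
end
struct DeviceArray{T,N}             # stand-in for JLDeviceArray
    data::Array{T,N}
end

struct Adaptor end
jlconvert(arg) = adapt(Adaptor(), arg)

# Leaf rule: swap the host array for its device counterpart
# (mirrors the adapt_storage method added for JLArray further down).
Adapt.adapt_storage(::Adaptor, x::HostArray{T,N}) where {T,N} =
    DeviceArray{T,N}(x.data)

# Container rule: Adapt.jl has no built-in Base.RefValue rule (hence the
# FIXME above), so wrap the adapted contents in a custom Ref subtype.
struct HostRefValue{T} <: Ref{T}
    x::T
end
Base.getindex(r::HostRefValue) = r.x
Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = HostRefValue(adapt(to, r[]))

a = HostArray(rand(Float32, 4))
jlconvert(a)            # DeviceArray{Float32,1}
jlconvert(Ref(a))       # HostRefValue wrapping a DeviceArray
jlconvert((a, 42))      # tuples recurse element-wise; plain values pass through

Broadcasting jlconvert over args then replaces the old to_device recursion: Adapt already knows how to descend through tuples, so only the array leaf and the Ref container need rules.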
@@ -267,8 +276,8 @@ GPUArrays.device(x::JLArray) = JLDevice()
 
 GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
 
-to_device(ctx, x::JLArray{T,N}) where {T,N} = JLDeviceArray{T,N}(x.data, x.dims)
-to_device(ctx, x::Base.RefValue{<: JLArray}) = Base.RefValue(to_device(ctx, x[]))
+Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
+    JLDeviceArray{T,N}(x.data, x.dims)
 
 GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
     reshape(reinterpret(T, A.data), size)
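
Why the Tuple and RefValue methods could be dropped: Adapt.jl splits a conversion into two roles. adapt_structure rules rebuild containers around their adapted contents (rules for tuples and named tuples ship with the package), while adapt_storage rules rewrite the leaves, with an identity fallback. The method above is a leaf rule, so argument tuples recurse for free and only Base.RefValue needs the hand-written structure rule from the earlier hunk. A minimal illustration of the split, with a hypothetical Wrapper type and ToUpper adaptor unrelated to this package:

using Adapt

struct Wrapper{T}
    inner::T
end
# Structure rule: rebuild the container around its adapted contents.
Adapt.adapt_structure(to, w::Wrapper) = Wrapper(adapt(to, w.inner))

struct ToUpper end
# Leaf rule: only strings change; every other leaf falls through unchanged.
Adapt.adapt_storage(::ToUpper, s::String) = uppercase(s)

adapt(ToUpper(), Wrapper(("abc", 1)))   # Wrapper(("ABC", 1))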