# reference objects

abstract type AbstractCuRef{T} <: Ref{T} end

## opaque reference type
##
## we use a concrete CuRef type that actual reference objects can be (no-op) converted
## to, without those objects being subtypes of CuRef. This is necessary so that `CuRef`
## can be used in `ccall` signatures; Base solves the same problem by special-casing
## `Ref` handling in `ccall.cpp`.
# forward declaration in pointer.jl

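# Illustrative sketch (comment only; `libfoo` and `foo_set_value` are hypothetical):
# because `CuRef` is pointer-sized, it can appear directly in foreign-call signatures,
# and any `AbstractCuRef` (or bare `CuPtr`) argument is converted by the methods below.
#
#   val = CuRef{Cfloat}(1f0)
#   @ccall libfoo.foo_set_value(val::CuRef{Cfloat})::Cvoid
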
# general methods for CuRef{T} type
Base.eltype(x::Type{<:CuRef{T}}) where {T} = @isdefined(T) ? T : Any

Base.convert(::Type{CuRef{T}}, x::CuRef{T}) where {T} = x

# conversion for the actual ccall
Base.unsafe_convert(::Type{CuRef{T}}, x::CuRef{T}) where {T} = Base.bitcast(CuRef{T}, Base.unsafe_convert(CuPtr{T}, x))
Base.unsafe_convert(::Type{CuRef{T}}, x) where {T} = Base.bitcast(CuRef{T}, Base.unsafe_convert(CuPtr{T}, x))
## `@gcsafe_ccall` results in "double conversions" (remove this once `ccall` is gcsafe)
Base.unsafe_convert(::Type{CuPtr{T}}, x::CuRef{T}) where {T} = Base.bitcast(CuPtr{T}, x)
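
# Conversion-chain sketch: for a `ccall` argument declared as `CuRef{T}`, Julia first
# calls `Base.cconvert(CuRef{T}, x)` (which falls back to `convert`, producing e.g. a
# `CuRefValue`), then `Base.unsafe_convert(CuRef{T}, ...)` on the result to obtain the
# pointer-sized bits that are actually passed to the foreign function.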

# CuRef from literal pointer
Base.convert(::Type{CuRef{T}}, x::CuPtr{T}) where {T} = x

# indirect constructors using CuRef
CuRef(x::Any) = CuRefValue(x)
CuRef{T}(x) where {T} = CuRefValue{T}(x)
CuRef{T}() where {T} = CuRefValue{T}()
Base.convert(::Type{CuRef{T}}, x) where {T} = CuRef{T}(x)

# idempotency
Base.convert(::Type{CuRef{T}}, x::AbstractCuRef{T}) where {T} = x
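
# Usage sketch (illustrative; assumes a functional CUDA device):
#
#   CuRef(42)          # CuRefValue{Int}
#   CuRef{Float32}(1)  # CuRefValue{Float32}; the value is converted first
#   CuRef{Int}()       # uninitialized single-element device box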


## reference backed by a single allocation

# TODO: maintain a small global cache of reference boxes

mutable struct CuRefValue{T} <: AbstractCuRef{T}
    buf::Managed{DeviceMemory}

    function CuRefValue{T}() where {T}
        check_eltype("CuRef", T)
        buf = pool_alloc(DeviceMemory, sizeof(T))
        obj = new(buf)
        finalizer(obj) do _
            pool_free(buf)
        end
        return obj
    end
end
function CuRefValue{T}(x::T) where {T}
    ref = CuRefValue{T}()
    ref[] = x
    return ref
end
CuRefValue{T}(x) where {T} = CuRefValue{T}(convert(T, x))
CuRefValue(x::T) where {T} = CuRefValue{T}(x)
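
# Lifetime sketch (comment only): each CuRefValue owns a single pool allocation, which
# the finalizer returns to the pool, so no manual free is needed:
#
#   ref = CuRefValue{Int32}(Int32(7))
#   # ... pass `ref` wherever a CuRef{Int32} is expected ...
#   # the backing memory is pool_free'd once `ref` is garbage-collected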

Base.unsafe_convert(::Type{CuPtr{T}}, b::CuRefValue{T}) where {T} = convert(CuPtr{T}, b.buf)
Base.unsafe_convert(P::Type{CuPtr{Any}}, b::CuRefValue{Any}) = convert(P, b.buf)
Base.unsafe_convert(::Type{CuPtr{Cvoid}}, b::CuRefValue{T}) where {T} =
    convert(CuPtr{Cvoid}, b.buf)

function Base.setindex!(gpu::CuRefValue{T}, x::T) where {T}
    cpu = Ref(x)
    GC.@preserve cpu begin
        cpu_ptr = Base.unsafe_convert(Ptr{T}, cpu)
        gpu_ptr = Base.unsafe_convert(CuPtr{T}, gpu)
        unsafe_copyto!(gpu_ptr, cpu_ptr, 1; async=true)
    end
    return gpu
end

function Base.getindex(gpu::CuRefValue{T}) where {T}
    # synchronize first to maximize time spent executing Julia code
    synchronize(gpu.buf)

    cpu = Ref{T}()
    GC.@preserve cpu begin
        cpu_ptr = Base.unsafe_convert(Ptr{T}, cpu)
        gpu_ptr = Base.unsafe_convert(CuPtr{T}, gpu)
        unsafe_copyto!(cpu_ptr, gpu_ptr, 1; async=false)
    end
    cpu[]
end
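
# Round-trip sketch: writes upload asynchronously on the current stream, while reads
# synchronize the buffer first, so a read always observes earlier writes:
#
#   ref = CuRefValue{Float64}()
#   ref[] = 3.0   # async host-to-device copy
#   ref[]         # synchronizes, then copies back; returns 3.0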

function Base.show(io::IO, x::CuRefValue{T}) where {T}
    print(io, "CuRefValue{$T}(")
    print(io, x[])
    print(io, ")")
end


## reference backed by a CUDA array at index i

struct CuRefArray{T,A<:AbstractArray{T}} <: AbstractCuRef{T}
    x::A
    i::Int
    CuRefArray{T,A}(x, i) where {T,A<:AbstractArray{T}} = new(x, i)
end
CuRefArray{T}(x::AbstractArray{T}, i::Int=1) where {T} = CuRefArray{T,typeof(x)}(x, i)
CuRefArray(x::AbstractArray{T}, i::Int=1) where {T} = CuRefArray{T}(x, i)
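
# Usage sketch (illustrative; requires a CUDA device): wrap element `i` of a device
# array as a reference, without copying:
#
#   a = CuArray(Float32[1, 2, 3])
#   r = CuRefArray(a, 2)   # refers to a[2]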

Base.convert(::Type{CuRef{T}}, x::AbstractArray{T}) where {T} = CuRefArray(x, 1)
Base.convert(::Type{CuRef{T}}, x::CuRefArray{T}) where {T} = x

Base.unsafe_convert(::Type{CuPtr{T}}, b::CuRefArray{T}) where {T} = pointer(b.x, b.i)
Base.unsafe_convert(P::Type{CuPtr{Any}}, b::CuRefArray{Any}) = convert(P, pointer(b.x, b.i))
Base.unsafe_convert(::Type{CuPtr{Cvoid}}, b::CuRefArray{T}) where {T} =
    convert(CuPtr{Cvoid}, Base.unsafe_convert(CuPtr{T}, b))

function Base.setindex!(gpu::CuRefArray{T}, x::T) where {T}
    cpu = Ref(x)
    GC.@preserve cpu begin
        cpu_ptr = Base.unsafe_convert(Ptr{T}, cpu)
        gpu_ptr = pointer(gpu.x, gpu.i)
        unsafe_copyto!(gpu_ptr, cpu_ptr, 1; async=true)
    end
    return gpu
end

function Base.getindex(gpu::CuRefArray{T}) where {T}
    # synchronize first to maximize time spent executing Julia code
    synchronize(gpu.x)

    cpu = Ref{T}()
    GC.@preserve cpu begin
        cpu_ptr = Base.unsafe_convert(Ptr{T}, cpu)
        gpu_ptr = pointer(gpu.x, gpu.i)
        unsafe_copyto!(cpu_ptr, gpu_ptr, 1; async=false)
    end
    cpu[]
end
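
# Write-through sketch, continuing the example above: updates through the reference
# mutate the underlying array, and reads synchronize the whole array first (see
# `synchronize(gpu.x)` above):
#
#   r[] = 5f0   # a[2] is now 5.0f0 on the device
#   r[]         # synchronizes `a`, then returns 5.0f0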

function Base.show(io::IO, x::CuRefArray{T}) where {T}
    print(io, "CuRefArray{$T}(")
    print(io, x[])
    print(io, ")")
end