Skip to content

Commit 5461475

Browse files
authored
Add a simpler CuRefValue. (#2645)
1 parent 3250f1e commit 5461475

File tree

6 files changed

+154
-69
lines changed

6 files changed

+154
-69
lines changed

src/CUDA.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ include("device/quirks.jl")
7272
# array essentials
7373
include("memory.jl")
7474
include("array.jl")
75+
include("refpointer.jl")
7576

7677
# compiler libraries
7778
include("../lib/cupti/CUPTI.jl")

src/array.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ end
4545
# these are stored with a selector at the end (handled by Julia).
4646
# 3. bitstype unions (`Union{Int, Float32}`, etc)
4747
# these are stored contiguously and require a selector array (handled by us)
48-
function check_eltype(T)
48+
@inline function check_eltype(name, T)
4949
if !Base.allocatedinline(T)
5050
explanation = explain_eltype(T)
5151
error("""
52-
CuArray only supports element types that are allocated inline.
52+
$name only supports element types that are allocated inline.
5353
$explanation""")
5454
end
5555
end
@@ -63,7 +63,7 @@ mutable struct CuArray{T,N,M} <: AbstractGPUArray{T,N}
6363
dims::Dims{N}
6464

6565
function CuArray{T,N,M}(::UndefInitializer, dims::Dims{N}) where {T,N,M}
66-
check_eltype(T)
66+
check_eltype("CuArray", T)
6767
maxsize = prod(dims) * sizeof(T)
6868
bufsize = if Base.isbitsunion(T)
6969
# type tag array past the data
@@ -82,7 +82,7 @@ mutable struct CuArray{T,N,M} <: AbstractGPUArray{T,N}
8282

8383
function CuArray{T,N}(data::DataRef{Managed{M}}, dims::Dims{N};
8484
maxsize::Int=prod(dims) * sizeof(T), offset::Int=0) where {T,N,M}
85-
check_eltype(T)
85+
check_eltype("CuArray", T)
8686
obj = new{T,N,M}(data, maxsize, offset, dims)
8787
finalizer(unsafe_free!, obj)
8888
return obj

src/compiler/execution.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,12 @@ end
168168
# Note that it isn't safe to use unified or heterogeneous memory to support a
169169
# mutable Ref, because there's no guarantee that the memory would be kept alive
170170
# long enough (especially with broadcast using ephemeral Refs for scalar args).
171-
struct CuRefValue{T} <: Ref{T}
171+
struct KernelRefValue{T} <: Ref{T}
172172
val::T
173173
end
174-
Base.getindex(r::CuRefValue{T}) where T = r.val
174+
Base.getindex(r::KernelRefValue{T}) where T = r.val
175175
Adapt.adapt_structure(to::KernelAdaptor, ref::Base.RefValue) =
176-
CuRefValue(adapt(to, ref[]))
176+
KernelRefValue(adapt(to, ref[]))
177177

178178
# broadcast sometimes passes a ref(type), resulting in a GPU-incompatible DataType box.
179179
# avoid that by using a special kind of ref that knows about the boxed type.

src/pointer.jl

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -207,68 +207,11 @@ Base.:(+)(x::Integer, y::CuArrayPtr) = y + x
207207

208208

209209
#
210-
# CUDA reference objects
210+
# CUDA reference objects (forward declaration)
211211
#
212212

213213
if sizeof(Ptr{Cvoid}) == 8
214214
primitive type CuRef{T} 64 end
215215
else
216216
primitive type CuRef{T} 32 end
217217
end
218-
219-
# general methods for CuRef{T} type
220-
Base.eltype(x::Type{<:CuRef{T}}) where {T} = @isdefined(T) ? T : Any
221-
222-
Base.convert(::Type{CuRef{T}}, x::CuRef{T}) where {T} = x
223-
224-
# conversion or the actual ccall
225-
Base.unsafe_convert(::Type{CuRef{T}}, x::CuRef{T}) where {T} = Base.bitcast(CuRef{T}, Base.unsafe_convert(CuPtr{T}, x))
226-
Base.unsafe_convert(::Type{CuRef{T}}, x) where {T} = Base.bitcast(CuRef{T}, Base.unsafe_convert(CuPtr{T}, x))
227-
## `@gcsafe_ccall` results in "double conversions" (remove this once `ccall` does `gcsafe`)
228-
Base.unsafe_convert(::Type{CuPtr{T}}, x::CuRef{T}) where {T} = x
229-
230-
# CuRef from literal pointer
231-
Base.convert(::Type{CuRef{T}}, x::CuPtr{T}) where {T} = x
232-
233-
# indirect constructors using CuRef
234-
CuRef(x::Any) = CuRefArray(CuArray([x]))
235-
CuRef{T}(x) where {T} = CuRefArray{T}(CuArray(T[x]))
236-
CuRef{T}() where {T} = CuRefArray(CuArray{T}(undef, 1))
237-
Base.convert(::Type{CuRef{T}}, x) where {T} = CuRef{T}(x)
238-
239-
240-
## CuRef object backed by a CUDA array at index i
241-
242-
struct CuRefArray{T,A<:AbstractArray{T}} <: Ref{T}
243-
x::A
244-
i::Int
245-
CuRefArray{T,A}(x,i) where {T,A<:AbstractArray{T}} = new(x,i)
246-
end
247-
CuRefArray{T}(x::AbstractArray{T}, i::Int=1) where {T} = CuRefArray{T,typeof(x)}(x, i)
248-
CuRefArray(x::AbstractArray{T}, i::Int=1) where {T} = CuRefArray{T}(x, i)
249-
Base.convert(::Type{CuRef{T}}, x::AbstractArray{T}) where {T} = CuRefArray(x, 1)
250-
Base.convert(::Type{CuRef{T}}, x::CuRefArray{T}) where {T} = x
251-
252-
function Base.unsafe_convert(P::Type{CuPtr{T}}, b::CuRefArray{T}) where T
253-
return pointer(b.x, b.i)
254-
end
255-
function Base.unsafe_convert(P::Type{CuPtr{Any}}, b::CuRefArray{Any})
256-
return convert(P, pointer(b.x, b.i))
257-
end
258-
Base.unsafe_convert(::Type{CuPtr{Cvoid}}, b::CuRefArray{T}) where {T} =
259-
convert(CuPtr{Cvoid}, Base.unsafe_convert(CuPtr{T}, b))
260-
261-
function Base.getindex(gpu::CuRefArray{T}) where {T}
262-
cpu = Ref{T}()
263-
GC.@preserve cpu begin
264-
cpu_ptr = Base.unsafe_convert(Ptr{T}, cpu)
265-
gpu_ptr = pointer(gpu.x, gpu.i)
266-
unsafe_copyto!(cpu_ptr, gpu_ptr, 1)
267-
end
268-
cpu[]
269-
end
270-
271-
272-
## Union with all CuRef 'subtypes'
273-
274-
const CuRefs{T} = Union{CuPtr{T}, CuRefArray{T}}

src/refpointer.jl

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# reference objects
2+
3+
# Common supertype of the reference objects declared below (`CuRefValue` and
# `CuRefArray`). It lets generic methods, such as the idempotent
# `convert(CuRef{T}, ::AbstractCuRef{T})`, be written once for all of them.
abstract type AbstractCuRef{T} <: Ref{T} end
4+
5+
## opaque reference type
6+
##
7+
## we use a concrete CuRef type that actual references can be (no-op) converted to, without
8+
## actually being a subtype of CuRef. This is necessary so that `CuRef` can be used in
9+
## `ccall` signatures; which Base solves by special-casing `Ref` handling in `ccall.cpp`.
10+
# forward declaration in pointer.jl
11+
12+
# general methods for CuRef{T} type
13+
function Base.eltype(::Type{<:CuRef{T}}) where {T}
    # guard against `T` being unbound, in which case `Any` is reported
    return @isdefined(T) ? T : Any
end

# a value that already is a `CuRef{T}` is passed through untouched
function Base.convert(::Type{CuRef{T}}, x::CuRef{T}) where {T}
    return x
end

# conversion for the actual ccall: obtain the `CuPtr{T}` for the argument, then
# reinterpret its bits as the opaque `CuRef{T}`
function Base.unsafe_convert(::Type{CuRef{T}}, x::CuRef{T}) where {T}
    ptr = Base.unsafe_convert(CuPtr{T}, x)
    return Base.bitcast(CuRef{T}, ptr)
end
function Base.unsafe_convert(::Type{CuRef{T}}, x) where {T}
    ptr = Base.unsafe_convert(CuPtr{T}, x)
    return Base.bitcast(CuRef{T}, ptr)
end
## `@gcsafe_ccall` results in "double conversions" (remove this once `ccall` does `gcsafe`)
function Base.unsafe_convert(::Type{CuPtr{T}}, x::CuRef{T}) where {T}
    return x
end

# CuRef from literal pointer
function Base.convert(::Type{CuRef{T}}, x::CuPtr{T}) where {T}
    return x
end

# indirect constructors using CuRef: all of them produce a `CuRefValue`
function CuRef(x::Any)
    return CuRefValue(x)
end
function CuRef{T}(x) where {T}
    return CuRefValue{T}(x)
end
function CuRef{T}() where {T}
    return CuRefValue{T}()
end
function Base.convert(::Type{CuRef{T}}, x) where {T}
    return CuRef{T}(x)
end

# idempotency: existing reference objects are not re-wrapped
function Base.convert(::Type{CuRef{T}}, x::AbstractCuRef{T}) where {T}
    return x
end
34+
35+
36+
## reference backed by a single allocation
37+
38+
# TODO: maintain a small global cache of reference boxes
39+
40+
# Reference to a single value of type `T`, stored in a buffer obtained from
# `pool_alloc(DeviceMemory, sizeof(T))`. A finalizer hands the buffer back to
# `pool_free` once the object is collected.
mutable struct CuRefValue{T} <: AbstractCuRef{T}
    buf::Managed{DeviceMemory}

    # Allocate an uninitialized reference; errors (via `check_eltype`) when `T`
    # is not allocated inline.
    function CuRefValue{T}() where {T}
        check_eltype("CuRef", T)
        mem = pool_alloc(DeviceMemory, sizeof(T))
        ref = new{T}(mem)
        # the closure captures `mem` itself rather than reading it back from `ref`
        finalizer(_ -> pool_free(mem), ref)
        return ref
    end
end
53+
# Wrap a value that already has type `T`: allocate a fresh reference, then store `x`.
function CuRefValue{T}(x::T) where {T}
    box = CuRefValue{T}()
    box[] = x
    return box
end

# Any other value is converted to `T` first and then wrapped.
function CuRefValue{T}(x) where {T}
    return CuRefValue{T}(convert(T, x))
end

# Without an explicit type parameter, the element type is the type of `x`.
function CuRefValue(x::T) where {T}
    return CuRefValue{T}(x)
end
60+
61+
# Pointer conversions for `CuRefValue`: each one converts the backing buffer
# `b.buf` to the requested `CuPtr` type.
function Base.unsafe_convert(::Type{CuPtr{T}}, b::CuRefValue{T}) where {T}
    return convert(CuPtr{T}, b.buf)
end
function Base.unsafe_convert(P::Type{CuPtr{Any}}, b::CuRefValue{Any})
    return convert(P, b.buf)
end
function Base.unsafe_convert(::Type{CuPtr{Cvoid}}, b::CuRefValue{T}) where {T}
    return convert(CuPtr{Cvoid}, b.buf)
end
65+
66+
# Store `x` in the reference and return the reference.
#
# The value is staged in a host-side `Ref{T}` and copied to the buffer behind `gpu`
# with `unsafe_copyto!(...; async=true)`. `GC.@preserve` keeps the host `Ref` rooted
# while its raw pointer is in use.
function Base.setindex!(gpu::CuRefValue{T}, x::T) where {T}
    # `Ref{T}` (rather than `Ref(x)`) mirrors `getindex` and pins the element type
    cpu = Ref{T}(x)
    GC.@preserve cpu begin
        cpu_ptr = Base.unsafe_convert(Ptr{T}, cpu)
        gpu_ptr = Base.unsafe_convert(CuPtr{T}, gpu)
        unsafe_copyto!(gpu_ptr, cpu_ptr, 1; async=true)
    end
    return gpu
end

# Converting fallback, consistent with the `CuRefValue{T}(x)` constructor and with
# `Base.RefValue`: `ref[] = 1` on a `CuRefValue{Float64}` stores `1.0` instead of
# throwing a `MethodError`. The `::T` assertion guarantees dispatch to the method above.
function Base.setindex!(gpu::CuRefValue{T}, x) where {T}
    return setindex!(gpu, convert(T, x)::T)
end
75+
76+
# Read the referenced value back: synchronize on the buffer, copy one element into
# a host-side `Ref{T}`, and return its contents.
function Base.getindex(gpu::CuRefValue{T}) where {T}
    # synchronize first to maximize time spent executing Julia code
    synchronize(gpu.buf)

    host = Ref{T}()
    # keep `host` rooted while its raw pointer is in use
    GC.@preserve host begin
        dst = Base.unsafe_convert(Ptr{T}, host)
        src = Base.unsafe_convert(CuPtr{T}, gpu)
        unsafe_copyto!(dst, src, 1; async=false)
    end
    return host[]
end
88+
89+
# Print as `CuRefValue{T}(value)`; note that `x[]` fetches the value via `getindex`.
function Base.show(io::IO, x::CuRefValue{T}) where {T}
    print(io, "CuRefValue{$T}(")
    # the prefix is written before the value is fetched, as a separate call
    print(io, x[], ")")
    return nothing
end
94+
95+
96+
## reference backed by a CUDA array at index i
97+
98+
# Reference to element `i` of the array `x`.
struct CuRefArray{T,A<:AbstractArray{T}} <: AbstractCuRef{T}
    x::A
    i::Int

    # fully-parameterized constructor; arguments are converted to the field types
    function CuRefArray{T,A}(x, i) where {T,A<:AbstractArray{T}}
        return new{T,A}(x, i)
    end
end
103+
# Convenience constructors: the array type is taken from `x`, and the index
# defaults to the first element.
function CuRefArray{T}(x::AbstractArray{T}, i::Int=1) where {T}
    return CuRefArray{T,typeof(x)}(x, i)
end
function CuRefArray(x::AbstractArray{T}, i::Int=1) where {T}
    return CuRefArray{T}(x, i)
end
105+
106+
# An array passed where a `CuRef{T}` is expected refers to its first element.
function Base.convert(::Type{CuRef{T}}, x::AbstractArray{T}) where {T}
    return CuRefArray(x, 1)
end

# An existing `CuRefArray` is passed through unchanged.
function Base.convert(::Type{CuRef{T}}, x::CuRefArray{T}) where {T}
    return x
end
108+
109+
# Pointer conversions for `CuRefArray`: all are based on `pointer(b.x, b.i)`.
function Base.unsafe_convert(::Type{CuPtr{T}}, b::CuRefArray{T}) where {T}
    return pointer(b.x, b.i)
end
function Base.unsafe_convert(P::Type{CuPtr{Any}}, b::CuRefArray{Any})
    elem_ptr = pointer(b.x, b.i)
    return convert(P, elem_ptr)
end
function Base.unsafe_convert(::Type{CuPtr{Cvoid}}, b::CuRefArray{T}) where {T}
    typed_ptr = Base.unsafe_convert(CuPtr{T}, b)
    return convert(CuPtr{Cvoid}, typed_ptr)
end
113+
114+
# Store `x` at the referenced array element and return the reference.
#
# The value is staged in a host-side `Ref{T}` and copied to `pointer(gpu.x, gpu.i)`
# with `unsafe_copyto!(...; async=true)`. `GC.@preserve` keeps the host `Ref` rooted
# while its raw pointer is in use.
function Base.setindex!(gpu::CuRefArray{T}, x::T) where {T}
    # `Ref{T}` (rather than `Ref(x)`) mirrors `getindex` and pins the element type
    cpu = Ref{T}(x)
    GC.@preserve cpu begin
        cpu_ptr = Base.unsafe_convert(Ptr{T}, cpu)
        gpu_ptr = pointer(gpu.x, gpu.i)
        unsafe_copyto!(gpu_ptr, cpu_ptr, 1; async=true)
    end
    return gpu
end

# Converting fallback, consistent with `Base.RefValue`: values of another type are
# converted to `T` instead of raising a `MethodError`. The `::T` assertion
# guarantees dispatch to the method above.
function Base.setindex!(gpu::CuRefArray{T}, x) where {T}
    return setindex!(gpu, convert(T, x)::T)
end
123+
124+
# Read the referenced element back: synchronize on the array, copy one element
# into a host-side `Ref{T}`, and return its contents.
function Base.getindex(gpu::CuRefArray{T}) where {T}
    # synchronize first to maximize time spent executing Julia code
    synchronize(gpu.x)

    host = Ref{T}()
    # keep `host` rooted while its raw pointer is in use
    GC.@preserve host begin
        dst = Base.unsafe_convert(Ptr{T}, host)
        src = pointer(gpu.x, gpu.i)
        unsafe_copyto!(dst, src, 1; async=false)
    end
    return host[]
end
136+
137+
# Print as `CuRefArray{T}(value)`; note that `x[]` fetches the value via `getindex`.
function Base.show(io::IO, x::CuRefArray{T}) where {T}
    print(io, "CuRefArray{$T}(")
    # the prefix is written before the value is fetched, as a separate call
    print(io, x[], ")")
    return nothing
end

test/libraries/cublas/level1.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,15 @@ k = 13
141141
@test CUBLAS.iamin(ca) == 3
142142
result_type = CUBLAS.version() >= v"12.0" ? Int64 : Cint
143143
result = CuRef{result_type}(0)
144-
result = CUBLAS.iamax(ca, result)
145-
@test BLAS.iamax(a) == only(Array(result.x))
144+
CUBLAS.iamax(ca, result)
145+
@test BLAS.iamax(a) == result[]
146146
end
147147
@testset "nrm2 with result" begin
148148
x = rand(T, m)
149149
dx = CuArray(x)
150150
result = CuRef{real(T)}(zero(real(T)))
151-
result = CUBLAS.nrm2(dx, result)
152-
@test norm(x) ≈ only(Array(result.x))
151+
CUBLAS.nrm2(dx, result)
152+
@test norm(x) ≈ result[]
153153
end
154154
end # level 1 testset
155155
@testset for T in [Float16, ComplexF16]

0 commit comments

Comments
 (0)