Skip to content

Commit 3d42ca2

Browse files
Use GPUToolbox.jl (#2646)
1 parent 236643b commit 3d42ca2

File tree

6 files changed

+10
-188
lines changed

6 files changed

+10
-188
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1515
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
1616
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1717
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
18+
GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
1819
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1920
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
2021
LLVMLoopInfo = "8b046642-f1f6-4319-8d3c-209ddc03c586"
@@ -61,6 +62,7 @@ EnzymeCore = "0.8.2"
6162
ExprTools = "0.1"
6263
GPUArrays = "11.2.1"
6364
GPUCompiler = "0.24, 0.25, 0.26, 0.27, 1"
65+
GPUToolbox = "0.1"
6466
KernelAbstractions = "0.9.2"
6567
LLVM = "9.1"
6668
LLVMLoopInfo = "1"

lib/utils/APIUtils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ using LLVM
66
using LLVM.Interop
77

88
# helpers that facilitate working with CUDA APIs
9+
using GPUToolbox: @checked, @debug_ccall, @gcsafe_ccall
10+
export @checked, @debug_ccall, @gcsafe_ccall
11+
912
include("call.jl")
1013
include("enum.jl")
1114
include("threading.jl")

lib/utils/call.jl

Lines changed: 1 addition & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,6 @@
11
# utilities for calling foreign functionality more conveniently
22

3-
export @checked, with_workspace, with_workspaces,
4-
@debug_ccall, @gcsafe_ccall
5-
6-
7-
## function wrapper for checking the return value of a function
8-
9-
"""
10-
@checked function foo(...)
11-
rv = ...
12-
return rv
13-
end
14-
15-
Macro for wrapping a function definition returning a status code. Two versions of the
16-
function will be generated: `foo`, with the function execution wrapped by an invocation of
17-
the `check` function (to be implemented by the caller of this macro), and `unchecked_foo`
18-
where no such invocation is present and the status code is returned to the caller.
19-
"""
20-
macro checked(ex)
21-
# parse the function definition
22-
@assert Meta.isexpr(ex, :function)
23-
sig = ex.args[1]
24-
@assert Meta.isexpr(sig, :call)
25-
body = ex.args[2]
26-
@assert Meta.isexpr(body, :block)
27-
28-
# make sure these functions are inlined
29-
pushfirst!(body.args, Expr(:meta, :inline))
30-
31-
# generate a "safe" version that performs a check
32-
safe_body = quote
33-
@inline
34-
check() do
35-
$body
36-
end
37-
end
38-
safe_sig = Expr(:call, sig.args[1], sig.args[2:end]...)
39-
safe_def = Expr(:function, safe_sig, safe_body)
40-
41-
# generate an "unchecked" version that returns the error code instead
42-
unchecked_sig = Expr(:call, Symbol("unchecked_", sig.args[1]), sig.args[2:end]...)
43-
unchecked_def = Expr(:function, unchecked_sig, body)
44-
45-
return esc(:($safe_def, $unchecked_def))
46-
end
47-
3+
export with_workspace, with_workspaces
484

495
## wrapper for foreign functionality that requires a workspace buffer
506

@@ -138,99 +94,3 @@ function with_workspaces(f::Base.Callable,
13894
end
13995
end
14096
end
141-
142-
143-
## version of ccall that prints the ccall, its arguments and its return value
144-
145-
macro debug_ccall(ex)
146-
@assert Meta.isexpr(ex, :(::))
147-
call, ret = ex.args
148-
@assert Meta.isexpr(call, :call)
149-
target, argexprs... = call.args
150-
args = map(argexprs) do argexpr
151-
@assert Meta.isexpr(argexpr, :(::))
152-
argexpr.args[1]
153-
end
154-
155-
ex = Expr(:macrocall, Symbol("@ccall"), __source__, ex)
156-
157-
# avoid task switches
158-
io = :(Core.stdout)
159-
160-
quote
161-
print($io, $(string(target)), '(')
162-
for (i, arg) in enumerate(($(map(esc, args)...),))
163-
i > 1 && print($io, ", ")
164-
render_arg($io, arg)
165-
end
166-
print($io, ')')
167-
168-
rv = $(esc(ex))
169-
170-
println($io, " = ", rv)
171-
for (i, arg) in enumerate(($(map(esc, args)...),))
172-
if arg isa Base.RefValue
173-
println($io, " $i: ", arg[])
174-
end
175-
end
176-
rv
177-
end
178-
end
179-
180-
render_arg(io, arg) = print(io, arg)
181-
render_arg(io, arg::AbstractArray) = summary(io, arg)
182-
render_arg(io, arg::Base.RefValue{T}) where {T} = print(io, "Ref{", T, "}")
183-
184-
185-
## version of ccall that calls jl_gc_safe_enter|leave around the inner ccall
186-
187-
# TODO: replace with JuliaLang/julia#49933 once merged
188-
189-
function ccall_macro_lower(func, rettype, types, args, nreq)
190-
# instead of re-using ccall or Expr(:foreigncall) to perform argument conversion,
191-
# we need to do so ourselves in order to insert a jl_gc_safe_enter|leave
192-
# just around the inner ccall
193-
194-
cconvert_exprs = []
195-
cconvert_args = []
196-
for (typ, arg) in zip(types, args)
197-
var = gensym("$(func)_cconvert")
198-
push!(cconvert_args, var)
199-
push!(cconvert_exprs, :($var = Base.cconvert($(esc(typ)), $(esc(arg)))))
200-
end
201-
202-
unsafe_convert_exprs = []
203-
unsafe_convert_args = []
204-
for (typ, arg) in zip(types, cconvert_args)
205-
var = gensym("$(func)_unsafe_convert")
206-
push!(unsafe_convert_args, var)
207-
push!(unsafe_convert_exprs, :($var = Base.unsafe_convert($(esc(typ)), $arg)))
208-
end
209-
210-
call = quote
211-
$(unsafe_convert_exprs...)
212-
213-
gc_state = @ccall(jl_gc_safe_enter()::Int8)
214-
ret = ccall($(esc(func)), $(esc(rettype)), $(Expr(:tuple, map(esc, types)...)),
215-
$(unsafe_convert_args...))
216-
@ccall(jl_gc_safe_leave(gc_state::Int8)::Cvoid)
217-
ret
218-
end
219-
220-
quote
221-
@inline
222-
$(cconvert_exprs...)
223-
GC.@preserve $(cconvert_args...) $(call)
224-
end
225-
end
226-
227-
"""
228-
@gcsafe_ccall ...
229-
230-
Call a foreign function just like `@ccall`, but marking it safe for the GC to run. This is
231-
useful for functions that may block, so that the GC isn't blocked from running, but may also
232-
be required to prevent deadlocks (see JuliaGPU/CUDA.jl#2261).
233-
"""
234-
macro gcsafe_ccall(expr)
235-
ccall_macro_lower(Base.ccall_macro_parse(expr)...)
236-
end

src/CUDA.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using GPUCompiler
44

55
using GPUArrays
66

7+
using GPUToolbox: SimpleVersion, @sv_str
8+
79
using LLVM
810
using LLVM.Interop
911
using Core: LLVMPtr

src/device/intrinsics/cooperative_groups.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ using ..LLVMLoopInfo
3232

3333
using Core: LLVMPtr
3434

35+
using GPUToolbox: @sv_str
36+
3537
const cg_debug = false
3638
if cg_debug
3739
cg_assert(x) = @cuassert x

src/device/intrinsics/version.jl

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,5 @@
11
# device intrinsics for querying the compute capability and PTX ISA version
22

3-
4-
## a GPU-compatible version number
5-
6-
export SimpleVersion, @sv_str
7-
8-
struct SimpleVersion
9-
major::UInt32
10-
minor::UInt32
11-
12-
SimpleVersion(major, minor=0) = new(major, minor)
13-
end
14-
15-
function Base.tryparse(::Type{SimpleVersion}, v::AbstractString)
16-
parts = split(v, ".")
17-
1 <= length(parts) <= 2 || return nothing
18-
19-
int_parts = map(parts) do part
20-
tryparse(Int, part)
21-
end
22-
any(isnothing, int_parts) && return nothing
23-
24-
SimpleVersion(int_parts...)
25-
end
26-
27-
function Base.parse(::Type{SimpleVersion}, v::AbstractString)
28-
ver = tryparse(SimpleVersion, v)
29-
ver === nothing && throw(ArgumentError("invalid SimpleVersion string: '$v'"))
30-
return ver
31-
end
32-
33-
SimpleVersion(v::AbstractString) = parse(SimpleVersion, v)
34-
35-
@inline function Base.isless(a::SimpleVersion, b::SimpleVersion)
36-
(a.major < b.major) && return true
37-
(a.major > b.major) && return false
38-
(a.minor < b.minor) && return true
39-
(a.minor > b.minor) && return false
40-
return false
41-
end
42-
43-
macro sv_str(str)
44-
SimpleVersion(str)
45-
end
46-
47-
48-
## accessors for the compute capability and PTX ISA version
49-
503
export compute_capability, ptx_isa_version
514

525
for var in ["sm_major", "sm_minor", "ptx_major", "ptx_minor"]

0 commit comments

Comments
 (0)