
Commit b6d3169

Author: William Moses (committed)
Commit message: wip
1 parent ad2fc22 commit b6d3169

File tree: 2 files changed (+192, -9 lines)


Project.toml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ version = "0.2.9"
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
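Note: the only change here is registering CUDA.jl as a direct dependency. A minimal sketch of the equivalent Pkg operation, assuming a local checkout of the package (the activate path is illustrative):

    using Pkg
    Pkg.activate(".")   # the Reactant.jl checkout containing this Project.toml
    Pkg.add("CUDA")     # records the [deps] entry with the UUID shown above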

ext/ReactantCUDAExt.jl

Lines changed: 191 additions & 9 deletions
@@ -8,35 +8,217 @@ using ReactantCore: @trace
 using Adapt

 function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
-    CuDeviceArray{T,N,CUDA.AS.Global}(pointer(xs.mlir_data.value), size(xs))
+    res = CuDeviceArray{T,N,CUDA.AS.Global}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, xs.mlir_data.value.ptr), size(xs))
+    @show res, xs
+    return res
 end

 const _kernel_instances = Dict{Any, Any}()

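The rewritten `adapt_storage` method now builds the `CuDeviceArray` from the raw pointer carried in `xs.mlir_data.value.ptr`, reinterpreted as an `LLVMPtr` in CUDA's global address space, with a debugging `@show` of the result. A minimal runnable sketch of the Adapt pattern it relies on, using a toy wrapper in place of `TracedRArray` (all names below are hypothetical, not part of the commit):

    using Adapt

    struct ToyAdaptor end                    # stand-in for CUDA.KernelAdaptor
    struct PtrWrapper{T,N}                   # stand-in for TracedRArray
        ptr::Ptr{T}
        dims::NTuple{N,Int}
    end

    # Like the method above: turn the wrapper into a pointer-backed view.
    Adapt.adapt_storage(::ToyAdaptor, w::PtrWrapper{T,N}) where {T,N} =
        unsafe_wrap(Array, w.ptr, w.dims)

    a = Float32[1, 2, 3]
    w = PtrWrapper(pointer(a), size(a))
    adapt(ToyAdaptor(), w) == a              # true: adapt() dispatched to adapt_storage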
+
+
+# compile to executable machine code
+function compile(job)
+    # lower to PTX
+    # TODO: on 1.9, this actually creates a context. cache those.
+    modstr = JuliaContext() do ctx
+        mod, meta = GPUCompiler.compile(:llvm, job)
+        string(mod)
+    end
+    return modstr
+    #=
+    # check if we'll need the device runtime
+    undefined_fs = filter(collect(functions(meta.ir))) do f
+        isdeclaration(f) && !LLVM.isintrinsic(f)
+    end
+    intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
+                     "__nvvm_reflect" #= TODO: should have been optimized away =#]
+    needs_cudadevrt = !isempty(setdiff(LLVM.name.(undefined_fs), intrinsic_fns))
+
+    # prepare invocations of CUDA compiler tools
+    ptxas_opts = String[]
+    nvlink_opts = String[]
+    ## debug flags
+    if Base.JLOptions().debug_level == 1
+        push!(ptxas_opts, "--generate-line-info")
+    elseif Base.JLOptions().debug_level >= 2
+        push!(ptxas_opts, "--device-debug")
+        push!(nvlink_opts, "--debug")
+    end
+    ## relocatable device code
+    if needs_cudadevrt
+        push!(ptxas_opts, "--compile-only")
+    end
+
+    ptx = job.config.params.ptx
+    cap = job.config.params.cap
+    arch = "sm_$(cap.major)$(cap.minor)"
+
+    # validate use of parameter memory
+    argtypes = filter([KernelState, job.source.specTypes.parameters...]) do dt
+        !isghosttype(dt) && !Core.Compiler.isconstType(dt)
+    end
+    param_usage = sum(sizeof, argtypes)
+    param_limit = 4096
+    if cap >= v"7.0" && ptx >= v"8.1"
+        param_limit = 32764
+    end
+    if param_usage > param_limit
+        msg = """Kernel invocation uses too much parameter memory.
+                 $(Base.format_bytes(param_usage)) exceeds the $(Base.format_bytes(param_limit)) limit imposed by sm_$(cap.major)$(cap.minor) / PTX v$(ptx.major).$(ptx.minor)."""
+
+        try
+            details = "\n\nRelevant parameters:"
+
+            source_types = job.source.specTypes.parameters
+            source_argnames = Base.method_argnames(job.source.def)
+            while length(source_argnames) < length(source_types)
+                # this is probably due to a trailing vararg; repeat its name
+                push!(source_argnames, source_argnames[end])
+            end
+
+            for (i, typ) in enumerate(source_types)
+                if isghosttype(typ) || Core.Compiler.isconstType(typ)
+                    continue
+                end
+                name = source_argnames[i]
+                details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))"
+            end
+            details *= "\n"
+
+            if cap >= v"7.0" && ptx < v"8.1" && param_usage < 32764
+                details *= "\nNote: use a newer CUDA to support more parameters on your device.\n"
+            end
+
+            msg *= details
+        catch err
+            @error "Failed to analyze kernel parameter usage; please file an issue with a reproducer."
+        end
+        error(msg)
+    end
+
+    # compile to machine code
+    # NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow
+    ptx_input = tempname(cleanup=false) * ".ptx"
+    ptxas_output = tempname(cleanup=false) * ".cubin"
+    write(ptx_input, asm)
+
+    # we could use the driver's embedded JIT compiler, but that has several disadvantages:
+    # 1. fixes and improvements are slower to arrive, by using `ptxas` we only need to
+    #    upgrade the toolkit to get a newer compiler;
+    # 2. version checking is simpler, we otherwise need to use NVML to query the driver
+    #    version, which is hard to correlate to PTX JIT improvements;
+    # 3. if we want to be able to use newer (minor upgrades) of the CUDA toolkit on an
+    #    older driver, we should use the newer compiler to ensure compatibility.
+    append!(ptxas_opts, [
+        "--verbose",
+        "--gpu-name", arch,
+        "--output-file", ptxas_output,
+        ptx_input
+    ])
+    proc, log = run_and_collect(`$(ptxas()) $ptxas_opts`)
+    log = strip(log)
+    if !success(proc)
+        reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" :
+                                       "ptxas exited with code $(proc.exitcode)"
+        msg = "Failed to compile PTX code ($reason)"
+        msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))"
+        if !isempty(log)
+            msg *= "\n" * log
+        end
+        msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)"
+        if parse(Bool, get(ENV, "BUILDKITE", "false"))
+            run(`buildkite-agent artifact upload $(ptx_input)`)
+        end
+        error(msg)
+    elseif !isempty(log)
+        @debug "PTX compiler log:\n" * log
+    end
+    rm(ptx_input)
+    =#
+    #=
+    # link device libraries, if necessary
+    #
+    # this requires relocatable device code, which prevents certain optimizations and
+    # hurts performance. as such, we only do so when absolutely necessary.
+    # TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`.
+    #       fails with `Ignoring -lto option because no LTO objects found`
+    if needs_cudadevrt
+        nvlink_output = tempname(cleanup=false) * ".cubin"
+        append!(nvlink_opts, [
+            "--verbose", "--extra-warnings",
+            "--arch", arch,
+            "--library-path", dirname(libcudadevrt),
+            "--library", "cudadevrt",
+            "--output-file", nvlink_output,
+            ptxas_output
+        ])
+        proc, log = run_and_collect(`$(nvlink()) $nvlink_opts`)
+        log = strip(log)
+        if !success(proc)
+            reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" :
+                                           "nvlink exited with code $(proc.exitcode)"
+            msg = "Failed to link PTX code ($reason)"
+            msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))"
+            if !isempty(log)
+                msg *= "\n" * log
+            end
+            msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)"
+            error(msg)
+        elseif !isempty(log)
+            @debug "PTX linker info log:\n" * log
+        end
+        rm(ptxas_output)
+
+        image = read(nvlink_output)
+        rm(nvlink_output)
+    else
+        image = read(ptxas_output)
+        rm(ptxas_output)
+    end
+    =#
+    return (image, entry=LLVM.name(meta.entry))
+end
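Only the code up to `return modstr` is live: `compile` currently stops at a textual LLVM module, and everything below is CUDA.jl's ptxas/nvlink pipeline preserved inside `#= =#` blocks for later use. If re-enabled as-is it would not run: `write(ptx_input, asm)` references `asm`, which is never defined in this scope (presumably it should be PTX emitted from `modstr`), and the final `return (image, ...)` is unreachable dead code. The parameter-memory check in the disabled block deserves a worked example; the limits below come straight from that code, while the argument types are hypothetical:

    # Kernel parameters live in .param space: 4096 bytes by default, raised to
    # 32764 bytes on sm_70+ when PTX >= 8.1 is available.
    argtypes = [Float32, NTuple{1024,Float64}]    # hypothetical kernel signature
    param_usage = sum(sizeof, argtypes)           # 4 + 8192 = 8196 bytes
    limit(cap, ptx) = (cap >= v"7.0" && ptx >= v"8.1") ? 32764 : 4096
    param_usage > limit(v"7.0", v"8.0")           # true  -> would raise the error
    param_usage > limit(v"7.0", v"8.1")           # false -> fits under 32764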
+
+# link into an executable kernel
+function link(job, compiled)
+    # load as an executable kernel object
+    return compiled
+end
+
+struct LLVMFunc{F,tt}
+    f::F
+    mod::String
+end
+
+function (func::LLVMFunc{F,tt})(args...) where {F,tt}
+
+end
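Since `link` is a passthrough, the `mod` field of `LLVMFunc` ends up holding the LLVM IR string produced by `compile`, and the call operator is still an empty stub in this WIP commit. For reference, a self-contained sketch of the callable-struct pattern it sets up (toy names, not part of the commit):

    struct ToyFunc{F}          # shaped like LLVMFunc: a function plus its module
        f::F
        mod::String
    end

    # The functor method LLVMFunc will eventually implement; here it just delegates.
    (tf::ToyFunc)(args...) = tf.f(args...)

    tf = ToyFunc(+, "; ModuleID = 'demo'")
    tf(1, 2)                   # 3: calls route through the struct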
+
 function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
     cuda = CUDA.active_state()
+    @show f, tt
+    flush(stdout)

     Base.@lock CUDA.cufunction_lock begin
         # compile the function
         cache = CUDA.compiler_cache(cuda.context)
         source = CUDA.methodinstance(F, tt)
         config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig
-        fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, CUDA.compile, CUDA.link)
+        fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)

         @show fun
-        @show fun.mod
+        println(string(fun))
+        #@show fun.mod
         # create a callable object that captures the function instance. we don't need to think
         # about world age here, as GPUCompiler already does and will return a different object
-        key = (objectid(source), hash(fun), f)
+        key = (objectid(source))
         kernel = get(_kernel_instances, key, nothing)
         if kernel === nothing
-            # create the kernel state object
-            state = CUDA.KernelState(create_exceptions!(fun.mod), UInt32(0))
-
-            kernel = CUDA.HostKernel{F,tt}(f, fun, state)
+            kernel = LLVMFunc{F,tt}(f, fun)
             _kernel_instances[key] = kernel
         end
-        return kernel::CUDA.HostKernel{F,tt}
+        return kernel::LLVMFunc{F,tt}
     end
 end

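`recufunction` mirrors `CUDA.cufunction` but swaps in the local `compile`/`link` pair, so the cached object is an `LLVMFunc` wrapping IR text rather than a `CUDA.HostKernel` (dropping the `KernelState` machinery along the way). Two small caveats in the new code: `(objectid(source))` is just `objectid(source)`, not a one-tuple, and narrowing the key to the method instance means distinct closures `f` of the same type now share one cache entry. A hedged usage sketch by analogy with `CUDA.cufunction` (the kernel and signature are hypothetical, and calling the result is not yet implemented):

    # Hypothetical device kernel
    function vadd!(a, b, c)
        i = CUDA.threadIdx().x
        c[i] = a[i] + b[i]
        return nothing
    end

    # fun = recufunction(vadd!, Tuple{CuDeviceVector{Float32,1},
    #                                 CuDeviceVector{Float32,1},
    #                                 CuDeviceVector{Float32,1}})
    # fun isa LLVMFunc    # cached per method instance in _kernel_instances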