@@ -8,35 +8,217 @@ using ReactantCore: @trace
 using Adapt

 function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
-    CuDeviceArray{T,N,CUDA.AS.Global}(pointer(xs.mlir_data.value), size(xs))
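+    # reinterpret the traced buffer's raw pointer as a global-address-space LLVMPtr so it
+    # can back a device-side array view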
+    res = CuDeviceArray{T,N,CUDA.AS.Global}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, xs.mlir_data.value.ptr), size(xs))
+    @show res, xs
+    return res
 end
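+
+# Note: this adaptor is what `CUDA.cudaconvert` (applied to every kernel argument via
+# `Adapt.adapt`) routes a TracedRArray through, so device code sees it as an ordinary
+# CuDeviceArray view over the traced buffer.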

 const _kernel_instances = Dict{Any,Any}()

+
+
+# compile to executable machine code
+function compile(job)
+    # lower to LLVM IR (the PTX path below is disabled for now)
+    # TODO: on 1.9, this actually creates a context. cache those.
+    modstr = JuliaContext() do ctx
+        mod, meta = GPUCompiler.compile(:llvm, job)
+        string(mod)
+    end
+    return modstr
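+
+    # The two #= ... =# blocks below mirror CUDA.jl's native ptxas / nvlink pipeline and are
+    # kept disabled: compilation currently stops at the LLVM IR string returned above, so the
+    # trailing `return (image, entry=...)` is unreachable for now.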
+    #=
+    # check if we'll need the device runtime
+    undefined_fs = filter(collect(functions(meta.ir))) do f
+        isdeclaration(f) && !LLVM.isintrinsic(f)
+    end
+    intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
+                     "__nvvm_reflect" #= TODO: should have been optimized away =#]
+    needs_cudadevrt = !isempty(setdiff(LLVM.name.(undefined_fs), intrinsic_fns))
+
+    # prepare invocations of CUDA compiler tools
+    ptxas_opts = String[]
+    nvlink_opts = String[]
+    ## debug flags
+    if Base.JLOptions().debug_level == 1
+        push!(ptxas_opts, "--generate-line-info")
+    elseif Base.JLOptions().debug_level >= 2
+        push!(ptxas_opts, "--device-debug")
+        push!(nvlink_opts, "--debug")
+    end
+    ## relocatable device code
+    if needs_cudadevrt
+        push!(ptxas_opts, "--compile-only")
+    end
+
+    ptx = job.config.params.ptx
+    cap = job.config.params.cap
+    arch = "sm_$(cap.major)$(cap.minor)"
+
+    # validate use of parameter memory
+    argtypes = filter([KernelState, job.source.specTypes.parameters...]) do dt
+        !isghosttype(dt) && !Core.Compiler.isconstType(dt)
+    end
+    param_usage = sum(sizeof, argtypes)
+    param_limit = 4096
+    if cap >= v"7.0" && ptx >= v"8.1"
+        param_limit = 32764
+    end
+    if param_usage > param_limit
+        msg = """Kernel invocation uses too much parameter memory.
+                 $(Base.format_bytes(param_usage)) exceeds the $(Base.format_bytes(param_limit)) limit imposed by sm_$(cap.major)$(cap.minor) / PTX v$(ptx.major).$(ptx.minor)."""
+
+        try
+            details = "\n\nRelevant parameters:"
+
+            source_types = job.source.specTypes.parameters
+            source_argnames = Base.method_argnames(job.source.def)
+            while length(source_argnames) < length(source_types)
+                # this is probably due to a trailing vararg; repeat its name
+                push!(source_argnames, source_argnames[end])
+            end
+
+            for (i, typ) in enumerate(source_types)
+                if isghosttype(typ) || Core.Compiler.isconstType(typ)
+                    continue
+                end
+                name = source_argnames[i]
+                details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))"
+            end
+            details *= "\n"
+
+            if cap >= v"7.0" && ptx < v"8.1" && param_usage < 32764
+                details *= "\nNote: use a newer CUDA to support more parameters on your device.\n"
+            end
+
+            msg *= details
+        catch err
+            @error "Failed to analyze kernel parameter usage; please file an issue with a reproducer."
+        end
+        error(msg)
+    end
+
+    # compile to machine code
+    # NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow
+    ptx_input = tempname(cleanup=false) * ".ptx"
+    ptxas_output = tempname(cleanup=false) * ".cubin"
+    write(ptx_input, asm)
+
+    # we could use the driver's embedded JIT compiler, but that has several disadvantages:
+    # 1. fixes and improvements are slower to arrive, by using `ptxas` we only need to
+    #    upgrade the toolkit to get a newer compiler;
+    # 2. version checking is simpler, we otherwise need to use NVML to query the driver
+    #    version, which is hard to correlate to PTX JIT improvements;
+    # 3. if we want to be able to use newer (minor upgrades) of the CUDA toolkit on an
+    #    older driver, we should use the newer compiler to ensure compatibility.
+    append!(ptxas_opts, [
+        "--verbose",
+        "--gpu-name", arch,
+        "--output-file", ptxas_output,
+        ptx_input
+    ])
+    proc, log = run_and_collect(`$(ptxas()) $ptxas_opts`)
+    log = strip(log)
+    if !success(proc)
+        reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" :
+                                       "ptxas exited with code $(proc.exitcode)"
+        msg = "Failed to compile PTX code ($reason)"
+        msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))"
+        if !isempty(log)
+            msg *= "\n" * log
+        end
+        msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)"
+        if parse(Bool, get(ENV, "BUILDKITE", "false"))
+            run(`buildkite-agent artifact upload $(ptx_input)`)
+        end
+        error(msg)
+    elseif !isempty(log)
+        @debug "PTX compiler log:\n" * log
+    end
+    rm(ptx_input)
+    =#
+    #=
+    # link device libraries, if necessary
+    #
+    # this requires relocatable device code, which prevents certain optimizations and
+    # hurts performance. as such, we only do so when absolutely necessary.
+    # TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`.
+    #       fails with `Ignoring -lto option because no LTO objects found`
+    if needs_cudadevrt
+        nvlink_output = tempname(cleanup=false) * ".cubin"
+        append!(nvlink_opts, [
+            "--verbose", "--extra-warnings",
+            "--arch", arch,
+            "--library-path", dirname(libcudadevrt),
+            "--library", "cudadevrt",
+            "--output-file", nvlink_output,
+            ptxas_output
+        ])
+        proc, log = run_and_collect(`$(nvlink()) $nvlink_opts`)
+        log = strip(log)
+        if !success(proc)
+            reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" :
+                                           "nvlink exited with code $(proc.exitcode)"
+            msg = "Failed to link PTX code ($reason)"
+            msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))"
+            if !isempty(log)
+                msg *= "\n" * log
+            end
+            msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)"
+            error(msg)
+        elseif !isempty(log)
+            @debug "PTX linker info log:\n" * log
+        end
+        rm(ptxas_output)
+
+        image = read(nvlink_output)
+        rm(nvlink_output)
+    else
+        image = read(ptxas_output)
+        rm(ptxas_output)
+    end
+    =#
+    return (image, entry=LLVM.name(meta.entry))
+end
+
+# link into an executable kernel
+function link(job, compiled)
+    # load as an executable kernel object
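+    # (for now this is a passthrough: the LLVM module string from `compile` is returned
+    # unchanged and wrapped by LLVMFunc below)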
+    return compiled
+end
+
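+# Callable wrapper pairing the original Julia function with the textual LLVM module that
+# `compile` produced; its call overload below is still an empty stub.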
+struct LLVMFunc{F,tt}
+    f::F
+    mod::String
+end
+
+function (func::LLVMFunc{F,tt})(args...) where {F,tt}
+
+end
+
 function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
     cuda = CUDA.active_state()
+    @show f, tt
+    flush(stdout)

     Base.@lock CUDA.cufunction_lock begin
         # compile the function
         cache = CUDA.compiler_cache(cuda.context)
         source = CUDA.methodinstance(F, tt)
         config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig
-        fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, CUDA.compile, CUDA.link)
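+        # route cached_compilation through the local `compile`/`link` defined above, so the
+        # cached `fun` is the textual LLVM module rather than a compiled CUDA function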
+        fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)

         @show fun
-        @show fun.mod
+        println(string(fun))
+        # @show fun.mod
         # create a callable object that captures the function instance. we don't need to think
         # about world age here, as GPUCompiler already does and will return a different object
-        key = (objectid(source), hash(fun), f)
+        key = (objectid(source))
         kernel = get(_kernel_instances, key, nothing)
         if kernel === nothing
-            # create the kernel state object
-            state = CUDA.KernelState(create_exceptions!(fun.mod), UInt32(0))
-
-            kernel = CUDA.HostKernel{F,tt}(f, fun, state)
+            kernel = LLVMFunc{F,tt}(f, fun)
             _kernel_instances[key] = kernel
         end
-        return kernel::CUDA.HostKernel{F,tt}
+        return kernel::LLVMFunc{F,tt}
     end
 end
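+
+# Rough usage sketch (hypothetical: the kernel body and argument types are illustrative,
+# and `LLVMFunc` is not callable yet):
+#
+#     square!(xs) = (xs[CUDA.threadIdx().x] ^= 2; nothing)
+#     kern = recufunction(square!, Tuple{CuDeviceArray{Float32,1,CUDA.AS.Global}})
+#     # `kern` wraps the LLVM IR string; actually launching it is future work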