@@ -200,17 +200,14 @@ end
 
 function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
     res = CuTracedArray{T,N,CUDA.AS.Global, size(xs)}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)))
-    @show res, xs
     return res
 end
 
 const _kernel_instances = Dict{Any, Any}()
 
 struct LLVMFunc{F,tt}
     f::Union{F, Nothing}
-    mod::String
-    image
-    entry::String
+    entry::MLIR.IR.Operation
 end
 
 
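The adapt_storage method above smuggles a host object across the adaptor boundary: pointer_from_objref turns the TracedRArray into a raw pointer, which is reinterpreted as a global-address-space LLVMPtr and stored inside the CuTracedArray; the launch overload later in this diff recovers the original object with unsafe_pointer_to_objref. A minimal sketch of that round trip, using only calls that appear in this patch; the element type, rank, and the variable xs are hypothetical placeholders, and xs must stay GC-rooted for the pointer to remain valid:

    # Hypothetical illustration of the pointer round trip used by this patch.
    # `xs` stands for a TracedRArray{Float32,1} produced by Reactant's tracing.
    ptr = Base.reinterpret(Core.LLVMPtr{Float32,CUDA.AS.Global}, Base.pointer_from_objref(xs))
    dev = CuTracedArray{Float32,1,CUDA.AS.Global,size(xs)}(ptr)   # what adapt_storage constructs
    # On the host side of the launch, the TracedRArray is recovered from the stored pointer:
    ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, dev.ptr))::TracedRArray
    ta === xs   # true, provided `xs` was kept alive in between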
@@ -249,11 +246,13 @@ CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_p
 
 # compile to executable machine code
 function compile(job)
-
     # lower to PTX
     # TODO: on 1.9, this actually creates a context. cache those.
-    modstr, image, entry = GPUCompiler.JuliaContext() do ctx
+    entry = GPUCompiler.JuliaContext() do ctx
         mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false)
+
+        entryname = LLVM.name(meta.entry)
+
         GPUCompiler.optimize_module!(job, mod)
         opt_level = 2
         tm = GPUCompiler.llvm_machine(job.config.target)
@@ -294,162 +293,15 @@ function compile(job)
         # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
         # it is probably safer to reparse a string using the right llvm module api, so we will do that.
 
-        println(string(modstr))
         mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
-        @show mmod
-
-        # check if we'll need the device runtime
-        undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f
-            CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
-        end
-        intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
-                         "__nvvm_reflect" #= TODO: should have been optimized away =#]
-        needs_cudadevrt = !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns))
-
-        # prepare invocations of CUDA compiler tools
-        ptxas_opts = String[]
-        nvlink_opts = String[]
-        ## debug flags
-        if Base.JLOptions().debug_level == 1
-            push!(ptxas_opts, "--generate-line-info")
-        elseif Base.JLOptions().debug_level >= 2
-            push!(ptxas_opts, "--device-debug")
-            push!(nvlink_opts, "--debug")
-        end
-        ## relocatable device code
-        if needs_cudadevrt
-            push!(ptxas_opts, "--compile-only")
-        end
-
-        ptx = job.config.params.ptx
-        cap = job.config.params.cap
-        arch = "sm_$(cap.major)$(cap.minor)"
-
-        # validate use of parameter memory
-        argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt
-            !CUDA.isghosttype(dt) && !Core.Compiler.isconstType(dt)
-        end
-        param_usage = sum(sizeof, argtypes)
-        param_limit = 4096
-        if cap >= v"7.0" && ptx >= v"8.1"
-            param_limit = 32764
-        end
-        if param_usage > param_limit
-            msg = """Kernel invocation uses too much parameter memory.
-                     $(Base.format_bytes(param_usage)) exceeds the $(Base.format_bytes(param_limit)) limit imposed by sm_$(cap.major)$(cap.minor) / PTX v$(ptx.major).$(ptx.minor)."""
-
-            try
-                details = "\n\nRelevant parameters:"
-
-                source_types = job.source.specTypes.parameters
-                source_argnames = Base.method_argnames(job.source.def)
-                while length(source_argnames) < length(source_types)
-                    # this is probably due to a trailing vararg; repeat its name
-                    push!(source_argnames, source_argnames[end])
-                end
-
-                for (i, typ) in enumerate(source_types)
-                    if CUDA.isghosttype(typ) || Core.Compiler.isconstType(typ)
-                        continue
-                    end
-                    name = source_argnames[i]
-                    details *= "\n  [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))"
-                end
-                details *= "\n"
-
-                if cap >= v"7.0" && ptx < v"8.1" && param_usage < 32764
-                    details *= "\nNote: use a newer CUDA to support more parameters on your device.\n"
-                end
-
-                msg *= details
-            catch err
-                @error "Failed to analyze kernel parameter usage; please file an issue with a reproducer."
-            end
-            error(msg)
-        end
-
-        # compile to machine code
-        # NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow
-        ptx_input = tempname(cleanup=false) * ".ptx"
-        ptxas_output = tempname(cleanup=false) * ".cubin"
-        write(ptx_input, asm)
-
-        # we could use the driver's embedded JIT compiler, but that has several disadvantages:
-        # 1. fixes and improvements are slower to arrive, by using `ptxas` we only need to
-        #    upgrade the toolkit to get a newer compiler;
-        # 2. version checking is simpler, we otherwise need to use NVML to query the driver
-        #    version, which is hard to correlate to PTX JIT improvements;
-        # 3. if we want to be able to use newer (minor upgrades) of the CUDA toolkit on an
-        #    older driver, we should use the newer compiler to ensure compatibility.
-        append!(ptxas_opts, [
-            "--verbose",
-            "--gpu-name", arch,
-            "--output-file", ptxas_output,
-            ptx_input
-        ])
-        proc, log = CUDA.run_and_collect(`$(CUDA.ptxas()) $ptxas_opts`)
-        log = strip(log)
-        if !success(proc)
-            reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" :
-                                           "ptxas exited with code $(proc.exitcode)"
-            msg = "Failed to compile PTX code ($reason)"
-            msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))"
-            if !isempty(log)
-                msg *= "\n" * log
-            end
-            msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)"
-            if parse(Bool, get(ENV, "BUILDKITE", "false"))
-                run(`buildkite-agent artifact upload $(ptx_input)`)
-            end
-            error(msg)
-        elseif !isempty(log)
-            @debug "PTX compiler log:\n" * log
-        end
-        rm(ptx_input)
-
-        # link device libraries, if necessary
-        #
-        # this requires relocatable device code, which prevents certain optimizations and
-        # hurts performance. as such, we only do so when absolutely necessary.
-        # TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`.
-        #       fails with `Ignoring -lto option because no LTO objects found`
-        if needs_cudadevrt
-            nvlink_output = tempname(cleanup=false) * ".cubin"
-            append!(nvlink_opts, [
-                "--verbose", "--extra-warnings",
-                "--arch", arch,
-                "--library-path", dirname(libcudadevrt),
-                "--library", "cudadevrt",
-                "--output-file", nvlink_output,
-                ptxas_output
-            ])
-            proc, log = run_and_collect(`$(CUDA.nvlink()) $nvlink_opts`)
-            log = strip(log)
-            if !success(proc)
-                reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" :
-                                               "nvlink exited with code $(proc.exitcode)"
-                msg = "Failed to link PTX code ($reason)"
-                msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))"
-                if !isempty(log)
-                    msg *= "\n" * log
-                end
-                msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)"
-                error(msg)
-            elseif !isempty(log)
-                @debug "PTX linker info log:\n" * log
-            end
-            rm(ptxas_output)
-
-            image = read(nvlink_output)
-            rm(nvlink_output)
-        else
-            image = read(ptxas_output)
-            rm(ptxas_output)
-        end
-
-        modstr, image, meta.entry
+
+        linkRes = @ccall MLIR.API.mlir_c.LinkInModule(MLIR.IR.mmodule()::MLIR.API.MlirModule, mmod::MLIR.API.MlirModule, entryname::Cstring)::MLIR.API.MlirOperation
+
+        entry = MLIR.IR.Operation(linkRes)
+
+        entry
     end
-    LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, modstr, image, CUDA.LLVM.name(entry))
+    LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, entry)
 end
 
 # link into an executable kernel
@@ -467,7 +319,6 @@
 
 Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
                                                 cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where {F, tt}
-    @show args
     @show call_kwargs
 
     blockdim = CUDA.CuDim3(blocks)
@@ -478,13 +329,11 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c
     aliases = MLIR.IR.Attribute[]
     rarrays = TracedRArray[]
     for (i, a) in enumerate(args)
-        @show a
         @assert a isa CuTracedArray
         ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
         push!(rarrays, ta)
         arg = ta.mlir_data
         arg = transpose_val(arg)
-        @show arg
         push!(restys, MLIR.IR.type(arg))
         push!(mlir_args, arg)
         push!(aliases,
@@ -500,11 +349,19 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c
     end
 
     output_operand_aliases = MLIR.IR.Attribute(aliases)
-    call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute("configstr"))
+
+    fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name")
+    # Force public for now while we don't have real users
+    MLIR.IR.rmattr!(func.entry, "sym_visibility")
+
+    call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(fname))
     # call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod))
     for (i, res) in enumerate(rarrays)
         res.mlir_data = transpose_val(MLIR.IR.result(call, i))
     end
+
+    @show blockdim
+    @show threaddim
     # CUDA.cuLaunchKernel(f,
     #    blockdim.x, blockdim.y, blockdim.z,
     #    threaddim.x, threaddim.y, threaddim.z,
@@ -523,12 +380,10 @@ function compiler_cache(ctx::MLIR.IR.Context)
 end
 
 Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
-    @show "recufunction", f, tt
     res = Base.@lock CUDA.cufunction_lock begin
         # compile the function
         cache = compiler_cache(MLIR.IR.context())
         source = CUDA.methodinstance(F, tt)
-
         # cuda = CUDA.active_state()
         device = nothing # cuda.device
         # config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig
@@ -543,7 +398,6 @@ Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tupl
         config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline)
         CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)
     end
-    @show res
     res
 end
 
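Read end to end, the patch replaces the ptxas/nvlink pipeline with an MLIR-level one: compile still lowers the kernel to LLVM IR, but instead of producing a CUBIN image it converts the textual IR to MLIR and links it into Reactant's module, and the launch overload then emits a stablehlo.custom_call whose backend_config carries the linked kernel's sym_name. A condensed sketch of that flow, assembled only from calls visible in this diff; job, modstr, mlir_args, restys, and output_operand_aliases are assumed to already exist as in the surrounding code:

    # Hypothetical condensation of the compile/launch path this patch introduces.
    entry = GPUCompiler.JuliaContext() do ctx
        mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false)
        entryname = LLVM.name(meta.entry)
        # Reparse the kernel's textual LLVM IR inside Reactant's MLIR context ...
        mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
        # ... and splice it into the module currently being traced, keeping the entry op.
        linkRes = @ccall MLIR.API.mlir_c.LinkInModule(MLIR.IR.mmodule()::MLIR.API.MlirModule, mmod::MLIR.API.MlirModule, entryname::Cstring)::MLIR.API.MlirOperation
        MLIR.IR.Operation(linkRes)
    end

    # At call time, reference the linked kernel by name from a stablehlo.custom_call.
    fname = Reactant.TracedUtils.get_attribute_by_name(entry, "sym_name")
    MLIR.IR.rmattr!(entry, "sym_visibility")   # force public visibility for now
    call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call",
        output_operand_aliases, backend_config=MLIR.IR.Attribute(fname))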