|
| 1 | +module ReactantCUDAExt |
| 2 | + |
| 3 | +using CUDA |
| 4 | +using Reactant: |
| 5 | + Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber |
| 6 | +using ReactantCore: @trace |
| 7 | + |
| 8 | + |
| 9 | +const _kernel_instances = Dict{Any, Any}() |
| 10 | + |
| 11 | +function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} |
| 12 | + cuda = CUDA.active_state() |
| 13 | + |
| 14 | + F2 = Reactant.traced_type(F, (), Val(Reactant.TracedToConcrete)) |
| 15 | + tt2 = Reactant.traced_type(tt, (), Val(Reactant.TracedToConcrete)) |
| 16 | + |
| 17 | + |
| 18 | + Base.@lock CUDA.cufunction_lock begin |
| 19 | + # compile the function |
| 20 | + cache = CUDA.compiler_cache(cuda.context) |
| 21 | + source = CUDA.methodinstance(F2, tt2) |
| 22 | + config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig |
| 23 | + fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, CUDA.compile, CUDA.link) |
| 24 | + |
| 25 | + @show fun |
| 26 | + @show fun.mod |
| 27 | + # create a callable object that captures the function instance. we don't need to think |
| 28 | + # about world age here, as GPUCompiler already does and will return a different object |
| 29 | + key = (objectid(source), hash(fun), f) |
| 30 | + kernel = get(_kernel_instances, key, nothing) |
| 31 | + if kernel === nothing |
| 32 | + # create the kernel state object |
| 33 | + state = CUDA.KernelState(create_exceptions!(fun.mod), UInt32(0)) |
| 34 | + |
| 35 | + kernel = CUDA.HostKernel{F,tt}(f, fun, state) |
| 36 | + _kernel_instances[key] = kernel |
| 37 | + end |
| 38 | + return kernel::CUDA.HostKernel{F,tt} |
| 39 | + end |
| 40 | +end |
| 41 | + |
| 42 | +const CC = Core.Compiler |
| 43 | + |
| 44 | +import Core.Compiler: |
| 45 | + AbstractInterpreter, |
| 46 | + abstract_call, |
| 47 | + abstract_call_known, |
| 48 | + ArgInfo, |
| 49 | + StmtInfo, |
| 50 | + AbsIntState, |
| 51 | + get_max_methods, |
| 52 | + CallMeta, |
| 53 | + Effects, |
| 54 | + NoCallInfo, |
| 55 | + widenconst, |
| 56 | + mapany, |
| 57 | + MethodResultPure |
| 58 | + |
| 59 | + |
| 60 | +function Reactant.set_reactant_abi( |
| 61 | + interp, |
| 62 | + f::typeof(CUDA.cufunction), |
| 63 | + arginfo::ArgInfo, |
| 64 | + si::StmtInfo, |
| 65 | + sv::AbsIntState, |
| 66 | + max_methods::Int=get_max_methods(interp, f, sv), |
| 67 | +) |
| 68 | + (; fargs, argtypes) = arginfo |
| 69 | + |
| 70 | + arginfo2 = ArgInfo( |
| 71 | + if fargs isa Nothing |
| 72 | + nothing |
| 73 | + else |
| 74 | + [:($(recufunction)), fargs[2:end]...] |
| 75 | + end, |
| 76 | + [Core.Const(recufunction), argtypes[2:end]...], |
| 77 | + ) |
| 78 | + return abstract_call_known(interp, recufunction, arginfo2, si, sv, max_methods) |
| 79 | +end |
| 80 | + |
| 81 | +end # module ReactantCUDAExt |
0 commit comments