Skip to content

Commit 7a113cf

Browse files
author
William Moses
committed
more files
1 parent 7ded1e4 commit 7a113cf

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
module ReactantCUDAExt
2+
3+
using CUDA
4+
using Reactant:
5+
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
6+
using ReactantCore: @trace
7+
8+
9+
const _kernel_instances = Dict{Any, Any}()
10+
11+
function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
12+
cuda = CUDA.active_state()
13+
14+
F2 = Reactant.traced_type(F, (), Val(Reactant.TracedToConcrete))
15+
tt2 = Reactant.traced_type(tt, (), Val(Reactant.TracedToConcrete))
16+
17+
18+
Base.@lock CUDA.cufunction_lock begin
19+
# compile the function
20+
cache = CUDA.compiler_cache(cuda.context)
21+
source = CUDA.methodinstance(F2, tt2)
22+
config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig
23+
fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, CUDA.compile, CUDA.link)
24+
25+
@show fun
26+
@show fun.mod
27+
# create a callable object that captures the function instance. we don't need to think
28+
# about world age here, as GPUCompiler already does and will return a different object
29+
key = (objectid(source), hash(fun), f)
30+
kernel = get(_kernel_instances, key, nothing)
31+
if kernel === nothing
32+
# create the kernel state object
33+
state = CUDA.KernelState(create_exceptions!(fun.mod), UInt32(0))
34+
35+
kernel = CUDA.HostKernel{F,tt}(f, fun, state)
36+
_kernel_instances[key] = kernel
37+
end
38+
return kernel::CUDA.HostKernel{F,tt}
39+
end
40+
end
41+
42+
const CC = Core.Compiler
43+
44+
import Core.Compiler:
45+
AbstractInterpreter,
46+
abstract_call,
47+
abstract_call_known,
48+
ArgInfo,
49+
StmtInfo,
50+
AbsIntState,
51+
get_max_methods,
52+
CallMeta,
53+
Effects,
54+
NoCallInfo,
55+
widenconst,
56+
mapany,
57+
MethodResultPure
58+
59+
60+
function Reactant.set_reactant_abi(
61+
interp,
62+
f::typeof(CUDA.cufunction),
63+
arginfo::ArgInfo,
64+
si::StmtInfo,
65+
sv::AbsIntState,
66+
max_methods::Int=get_max_methods(interp, f, sv),
67+
)
68+
(; fargs, argtypes) = arginfo
69+
70+
arginfo2 = ArgInfo(
71+
if fargs isa Nothing
72+
nothing
73+
else
74+
[:($(recufunction)), fargs[2:end]...]
75+
end,
76+
[Core.Const(recufunction), argtypes[2:end]...],
77+
)
78+
return abstract_call_known(interp, recufunction, arginfo2, si, sv, max_methods)
79+
end
80+
81+
end # module ReactantCUDAExt

0 commit comments

Comments
 (0)