Skip to content

Commit b64023d

Browse files
author
William Moses
committed
wqtmp
1 parent b6d3169 commit b64023d

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@ end
1616
const _kernel_instances = Dict{Any, Any}()
1717

1818

19-
2019
# compile to executable machine code
2120
function compile(job)
2221
# lower to PTX
2322
# TODO: on 1.9, this actually creates a context. cache those.
24-
modstr = JuliaContext() do ctx
25-
mod, meta = GPUCompiler.compile(:llvm, job)
23+
modstr = CUDA.GPUCompiler.JuliaContext() do ctx
24+
mod, meta = CUDA.GPUCompiler.compile(:llvm, job)
2625
string(mod)
2726
end
27+
println(string(modstr))
28+
@show job
29+
@show job.params
30+
@show job.source
31+
kernel = LLVMFunc{F,tt}(f, modstr)
2832
return modstr
2933
#=
3034
# check if we'll need the device runtime
@@ -187,12 +191,23 @@ function link(job, compiled)
187191
end
188192

189193
struct LLVMFunc{F,tt}
190-
f::F
191-
mod::String
194+
f::F
195+
mod::String
192196
end
193197

194198
function (func::LLVMFunc{F,tt})(args...) where{F, tt}
195-
199+
200+
end
201+
202+
# cache of compilation caches, per context
203+
const _compiler_caches = Dict{MLIR.IR.Context, Dict{Any, LLVMFunc}}();
204+
function compiler_cache(ctx::MLIR.IR.Context)
205+
cache = get(_compiler_caches, ctx, nothing)
206+
if cache === nothing
207+
cache = Dict{Any, LLVMFunc}()
208+
_compiler_caches[ctx] = cache
209+
end
210+
return cache
196211
end
197212

198213
function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
@@ -202,20 +217,17 @@ function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
202217

203218
Base.@lock CUDA.cufunction_lock begin
204219
# compile the function
205-
cache = CUDA.compiler_cache(cuda.context)
220+
cache = compiler_cache(MLIR.IR.context())
206221
source = CUDA.methodinstance(F, tt)
207222
config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig
208223
fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)
209224

210-
@show fun
211-
println(string(fun))
212225
#@show fun.mod
213226
# create a callable object that captures the function instance. we don't need to think
214227
# about world age here, as GPUCompiler already does and will return a different object
215228
key = (objectid(source))
216229
kernel = get(_kernel_instances, key, nothing)
217230
if kernel === nothing
218-
kernel = LLVMFunc{F,tt}(f, fun)
219231
_kernel_instances[key] = kernel
220232
end
221233
return kernel::LLVMFunc{F,tt}

0 commit comments

Comments
 (0)