1616const _kernel_instances = Dict {Any, Any} ()
1717
1818
19-
2019# compile to executable machine code
2120function compile (job)
2221 # lower to PTX
2322 # TODO : on 1.9, this actually creates a context. cache those.
24- modstr = JuliaContext () do ctx
25- mod, meta = GPUCompiler. compile (:llvm , job)
23+ modstr = CUDA . GPUCompiler . JuliaContext () do ctx
24+ mod, meta = CUDA . GPUCompiler. compile (:llvm , job)
2625 string (mod)
2726 end
27+ println (string (modstr))
28+ @show job
29+ @show job. params
30+ @show job. source
31+ kernel = LLVMFunc {F,tt} (f, modstr)
2832 return modstr
2933#=
3034 # check if we'll need the device runtime
@@ -187,12 +191,23 @@ function link(job, compiled)
187191end
188192
189193struct LLVMFunc{F,tt}
190- f:: F
191- mod:: String
194+ f:: F
195+ mod:: String
192196end
193197
194198function (func:: LLVMFunc{F,tt} )(args... ) where {F, tt}
195-
199+
200+ end
201+
202+ # cache of compilation caches, per context
203+ const _compiler_caches = Dict {MLIR.IR.Context, Dict{Any, LLVMFunc}} ();
204+ function compiler_cache (ctx:: MLIR.IR.Context )
205+ cache = get (_compiler_caches, ctx, nothing )
206+ if cache === nothing
207+ cache = Dict {Any, LLVMFunc} ()
208+ _compiler_caches[ctx] = cache
209+ end
210+ return cache
196211end
197212
198213function recufunction (f:: F , tt:: TT = Tuple{}; kwargs... ) where {F,TT}
@@ -202,20 +217,17 @@ function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
202217
203218 Base. @lock CUDA. cufunction_lock begin
204219 # compile the function
205- cache = CUDA . compiler_cache (cuda . context)
220+ cache = compiler_cache (MLIR . IR . context () )
206221 source = CUDA. methodinstance (F, tt)
207222 config = CUDA. compiler_config (cuda. device; kwargs... ):: CUDA.CUDACompilerConfig
208223 fun = CUDA. GPUCompiler. cached_compilation (cache, source, config, compile, link)
209224
210- @show fun
211- println (string (fun))
212225 # @show fun.mod
213226 # create a callable object that captures the function instance. we don't need to think
214227 # about world age here, as GPUCompiler already does and will return a different object
215228 key = (objectid (source))
216229 kernel = get (_kernel_instances, key, nothing )
217230 if kernel === nothing
218- kernel = LLVMFunc {F,tt} (f, fun)
219231 _kernel_instances[key] = kernel
220232 end
221233 return kernel:: LLVMFunc{F,tt}
0 commit comments