@@ -221,7 +221,15 @@ function compile(job)
     modstr, image, entry = CUDA.GPUCompiler.JuliaContext() do ctx
         asm, meta = CUDA.GPUCompiler.compile(:asm, job)
         mod = meta.ir
+
         modstr = string(mod)
+
+        # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
+        # it is probably safer to reparse a string using the right llvm module api, so we will do that.
+
+        mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMToMLIR(mod::CUDA.LLVM.API.LLVMModuleRef, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
+        @show mmod
+
         # check if we'll need the device runtime
         undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f
             CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
@@ -424,7 +432,8 @@ function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, th
     end

     output_operand_aliases = MLIR.IR.Attribute(aliases)
-    call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases)
+    call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute("configstr"))
+    # call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod))
     for (i, res) in enumerate(rarrays)
         res.mlir_data = transpose_val(MLIR.IR.result(call, i))
     end
@@ -459,4 +468,8 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg
     res
 end

+function __init__()
+
+end
+
 end # module ReactantCUDAExt
0 commit comments