Skip to content

Commit 3bfd8a3

Browse files
William Moseswsmoses
authored andcommitted
conversion
1 parent ecdf0cb commit 3bfd8a3

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,15 @@ function compile(job)
221221
modstr, image, entry = CUDA.GPUCompiler.JuliaContext() do ctx
222222
asm, meta = CUDA.GPUCompiler.compile(:asm, job)
223223
mod = meta.ir
224+
224225
modstr = string(mod)
226+
227+
# This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
228+
# it is probably safer to reparse a string using the right llvm module api, so we will do that.
229+
230+
mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMToMLIR(mod::CUDA.LLVM.API.LLVMModuleRef, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
231+
@show mmod
232+
225233
# check if we'll need the device runtime
226234
undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f
227235
CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
@@ -424,7 +432,8 @@ function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, th
424432
end
425433

426434
output_operand_aliases=MLIR.IR.Attribute(aliases)
427-
call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases)
435+
call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute("configstr"))
436+
# call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod))
428437
for (i, res) in enumerate(rarrays)
429438
res.mlir_data = transpose_val(MLIR.IR.result(call, i))
430439
end
@@ -459,4 +468,8 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg
459468
res
460469
end
461470

471+
function __init__()
472+
473+
end
474+
462475
end # module ReactantCUDAExt

0 commit comments

Comments
 (0)