Skip to content

Commit 8e01310

Browse files
author
William Moses
committed
conversion
1 parent 1d42379 commit 8e01310

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,14 @@ extern "C" MlirModule ConvertLLVMToMLIR(LLVMModuleRef lmod, MlirContext cctx) {
376376
return wrap(res);
377377
}
378378

379+
extern "C" MlirModule ConvertLLVMStrToMLIR(const char* lmod, MlirContext cctx) {
380+
LLVMContext Context;
381+
auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Context);
382+
mlir::MLIRContext &context = *unwrap(cctx);
383+
auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context, /*emitExpensiveWarnings*/false, /*dropDICompositeElements*/false).release();
384+
return wrap(res);
385+
}
386+
379387

380388
/* Note that this */
381389
extern "C" xla::PjRtLoadedExecutable* ClientCompile(PjRtClient * client, MlirModule cmod) {

ext/ReactantCUDAExt.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,15 @@ function compile(job)
221221
modstr, image, entry = CUDA.GPUCompiler.JuliaContext() do ctx
222222
asm, meta = CUDA.GPUCompiler.compile(:asm, job)
223223
mod = meta.ir
224+
224225
modstr = string(mod)
226+
227+
# This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
228+
# it is probably safer to reparse a string using the right llvm module api, so we will do that.
229+
230+
mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMToMLIR(mod::CUDA.LLVM.API.LLVMModuleRef, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
231+
@show mmod
232+
225233
# check if we'll need the device runtime
226234
undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f
227235
CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
@@ -424,7 +432,8 @@ function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, th
424432
end
425433

426434
output_operand_aliases=MLIR.IR.Attribute(aliases)
427-
call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases)
435+
call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute("configstr"))
436+
# call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod))
428437
for (i, res) in enumerate(rarrays)
429438
res.mlir_data = transpose_val(MLIR.IR.result(call, i))
430439
end
@@ -459,4 +468,8 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg
459468
res
460469
end
461470

471+
function __init__()
472+
473+
end
474+
462475
end # module ReactantCUDAExt

0 commit comments

Comments
 (0)