Commit d0e5195

host and device IR
1 parent 0c61f5d commit d0e5195

File tree

4 files changed: +111 additions, −168 deletions


deps/ReactantExtra/API.cpp

Lines changed: 84 additions & 0 deletions
@@ -470,6 +470,10 @@ extern "C" void RegisterDialects(MlirContext cctx) {
   context.loadDialect<mlir::stablehlo::StablehloDialect>();
   context.loadDialect<mlir::chlo::ChloDialect>();
 }
+
+#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
+#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
 extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
   mlir::DialectRegistry &registry = *unwrap(creg);

@@ -513,6 +517,11 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
   mlir::affine::registerAffinePasses();
   mlir::registerReconcileUnrealizedCasts();

+  mlir::registerLLVMDialectImport(registry);
+  mlir::registerNVVMDialectImport(registry);
+
+  mlir::LLVM::registerInlinerInterface(registry);
+
   /*
   registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
     LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
@@ -540,6 +549,81 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
   mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
 }

+
+/// Returns a symbol name unused in both `source` and `target`, derived from
+/// `oldSymName` by appending a numeric suffix; the search starts at
+/// `lastUsedID`.
+static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
+                                     unsigned &lastUsedID,
+                                     mlir::ModuleOp source,
+                                     mlir::ModuleOp target) {
+  using namespace llvm;
+  using namespace mlir;
+  SmallString<64> newSymName(oldSymName);
+  newSymName.push_back('_');
+  while (true) {
+    auto possible = newSymName + Twine(++lastUsedID);
+    if (!SymbolTable::lookupSymbolIn(source, possible.str()) &&
+        !SymbolTable::lookupSymbolIn(target, possible.str())) {
+      return StringAttr::get(target.getContext(), possible);
+    }
+  }
+}
+
+
+/// Checks if a symbol with the same name as `op` already exists in `target`.
+/// If so, renames `op` and updates all its references in `source`.
+static mlir::LogicalResult
+updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp source,
+                       mlir::ModuleOp target, unsigned &lastUsedID) {
+  using namespace llvm;
+  using namespace mlir;
+
+  auto opName = op.getName().str();
+
+  if (!SymbolTable::lookupSymbolIn(target, opName)) {
+    return success();
+  }
+
+  StringAttr newSymName = renameSymbol(opName, lastUsedID, source, target);
+
+  if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
+    return op.emitError("unable to update all symbol uses for ")
+           << opName << " to " << newSymName;
+
+  SymbolTable::setSymbolName(op, newSymName);
+  return success();
+}
+
+extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
+                                      const char *entryfn) {
+  auto prevMod = cast<ModuleOp>(*unwrap(prevModC));
+  auto newMod = cast<ModuleOp>(*unwrap(newModC));
+
+  Operation *entryFn = nullptr;
+
+  unsigned lastUsedID = 0;
+
+  for (auto &op : *newMod.getBody()) {
+    auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+    if (!symbolOp)
+      continue;
+
+    StringRef oldSymName = symbolOp.getName();
+
+    if (oldSymName == entryfn) {
+      entryFn = &op;
+    }
+
+    if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod,
+                                      lastUsedID))) {
+      assert(0 && "failed to update all uses");
+    }
+    SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
+  }
+  prevMod.getBody()->getOperations().splice(
+      prevMod.getBody()->getOperations().end(),
+      newMod.getBody()->getOperations());
+  return wrap(entryFn);
+}
+
 #pragma region xla::ifrt

 #pragma region xla::ifrt::Value
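In effect, LinkInModule splices every top-level operation of the device module into the host module, renaming any symbol that would collide (foo becomes foo_1, foo_2, ...), rewriting its uses, marking the spliced symbols private, and returning the operation whose symbol matches `entryfn`. A minimal sketch of driving it from Julia through the C API bindings (`host_mod`, `device_mod`, and "my_kernel" are hypothetical placeholders; the real call site is in ext/ReactantCUDAExt.jl below):

    # host_mod and device_mod are assumed to be MLIR.IR.Module values, and
    # "my_kernel" a symbol defined inside device_mod.
    linked = @ccall MLIR.API.mlir_c.LinkInModule(
        host_mod::MLIR.API.MlirModule,
        device_mod::MLIR.API.MlirModule,
        "my_kernel"::Cstring)::MLIR.API.MlirOperation

    # The result is the (possibly renamed) entry function, now owned by
    # host_mod and, like every spliced symbol, marked private.
    entry = MLIR.IR.Operation(linked)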

deps/ReactantExtra/BUILD

Lines changed: 5 additions & 0 deletions
@@ -416,6 +416,7 @@ cc_library(
         "-Wl,-exported_symbol,_BufferToHost",
         "-Wl,-exported_symbol,_FreeClient",
         "-Wl,-exported_symbol,_ClientCompile",
+        "-Wl,-exported_symbol,_LinkInModule",
         "-Wl,-exported_symbol,_FreeFuture",
         "-Wl,-exported_symbol,_FutureIsReady",
         "-Wl,-exported_symbol,_FutureAwait",
@@ -451,6 +452,10 @@ cc_library(
         "@llvm-project//mlir:TransformDialect",
         "@llvm-project//mlir:Transforms",

+        "@llvm-project//mlir:LLVMIRToLLVMTranslation",
+        "@llvm-project//mlir:LLVMIRToNVVMTranslation",
+        "@llvm-project//mlir:LLVMIRTransforms",
+
         "@llvm-project//llvm:IRReader",
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:AArch64AsmParser",

ext/ReactantCUDAExt.jl

Lines changed: 21 additions & 167 deletions
@@ -200,17 +200,14 @@ end

 function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
     res = CuTracedArray{T,N,CUDA.AS.Global, size(xs)}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)))
-    @show res, xs
     return res
 end

 const _kernel_instances = Dict{Any, Any}()

 struct LLVMFunc{F,tt}
     f::Union{F, Nothing}
-    mod::String
-    image
-    entry::String
+    entry::MLIR.IR.Operation
 end


@@ -249,11 +246,13 @@ CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_p

 # compile to executable machine code
 function compile(job)
-
     # lower to PTX
     # TODO: on 1.9, this actually creates a context. cache those.
-    modstr, image, entry = GPUCompiler.JuliaContext() do ctx
+    entry = GPUCompiler.JuliaContext() do ctx
         mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false)
+
+        entryname = LLVM.name(meta.entry)
+
         GPUCompiler.optimize_module!(job, mod)
         opt_level = 2
         tm = GPUCompiler.llvm_machine(job.config.target)
@@ -294,162 +293,15 @@ function compile(job)
         # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
         # it is probably safer to reparse a string using the right llvm module api, so we will do that.

-        println(string(modstr))
         mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
-        @show mmod
-
-        # check if we'll need the device runtime
-        undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f
-            CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f)
-        end
-        intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
-                         "__nvvm_reflect" #= TODO: should have been optimized away =#]
-        needs_cudadevrt = !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns))
-
-        # prepare invocations of CUDA compiler tools
-        ptxas_opts = String[]
-        nvlink_opts = String[]
-        ## debug flags
-        if Base.JLOptions().debug_level == 1
-            push!(ptxas_opts, "--generate-line-info")
-        elseif Base.JLOptions().debug_level >= 2
-            push!(ptxas_opts, "--device-debug")
-            push!(nvlink_opts, "--debug")
-        end
-        ## relocatable device code
-        if needs_cudadevrt
-            push!(ptxas_opts, "--compile-only")
-        end
-
-        ptx = job.config.params.ptx
-        cap = job.config.params.cap
-        arch = "sm_$(cap.major)$(cap.minor)"
-
-        # validate use of parameter memory
-        argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt
-            !CUDA.isghosttype(dt) && !Core.Compiler.isconstType(dt)
-        end
-        param_usage = sum(sizeof, argtypes)
-        param_limit = 4096
-        if cap >= v"7.0" && ptx >= v"8.1"
-            param_limit = 32764
-        end
-        if param_usage > param_limit
-            msg = """Kernel invocation uses too much parameter memory.
-                     $(Base.format_bytes(param_usage)) exceeds the $(Base.format_bytes(param_limit)) limit imposed by sm_$(cap.major)$(cap.minor) / PTX v$(ptx.major).$(ptx.minor)."""
-
-            try
-                details = "\n\nRelevant parameters:"
-
-                source_types = job.source.specTypes.parameters
-                source_argnames = Base.method_argnames(job.source.def)
-                while length(source_argnames) < length(source_types)
-                    # this is probably due to a trailing vararg; repeat its name
-                    push!(source_argnames, source_argnames[end])
-                end
-
-                for (i, typ) in enumerate(source_types)
-                    if CUDA.isghosttype(typ) || Core.Compiler.isconstType(typ)
-                        continue
-                    end
-                    name = source_argnames[i]
-                    details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))"
-                end
-                details *= "\n"
-
-                if cap >= v"7.0" && ptx < v"8.1" && param_usage < 32764
-                    details *= "\nNote: use a newer CUDA to support more parameters on your device.\n"
-                end
-
-                msg *= details
-            catch err
-                @error "Failed to analyze kernel parameter usage; please file an issue with a reproducer."
-            end
-            error(msg)
-        end
-
-        # compile to machine code
-        # NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow
-        ptx_input = tempname(cleanup=false) * ".ptx"
-        ptxas_output = tempname(cleanup=false) * ".cubin"
-        write(ptx_input, asm)
-
-        # we could use the driver's embedded JIT compiler, but that has several disadvantages:
-        # 1. fixes and improvements are slower to arrive, by using `ptxas` we only need to
-        #    upgrade the toolkit to get a newer compiler;
-        # 2. version checking is simpler, we otherwise need to use NVML to query the driver
-        #    version, which is hard to correlate to PTX JIT improvements;
-        # 3. if we want to be able to use newer (minor upgrades) of the CUDA toolkit on an
-        #    older driver, we should use the newer compiler to ensure compatibility.
-        append!(ptxas_opts, [
-            "--verbose",
-            "--gpu-name", arch,
-            "--output-file", ptxas_output,
-            ptx_input
-        ])
-        proc, log = CUDA.run_and_collect(`$(CUDA.ptxas()) $ptxas_opts`)
-        log = strip(log)
-        if !success(proc)
-            reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" :
-                                           "ptxas exited with code $(proc.exitcode)"
-            msg = "Failed to compile PTX code ($reason)"
-            msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))"
-            if !isempty(log)
-                msg *= "\n" * log
-            end
-            msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)"
-            if parse(Bool, get(ENV, "BUILDKITE", "false"))
-                run(`buildkite-agent artifact upload $(ptx_input)`)
-            end
-            error(msg)
-        elseif !isempty(log)
-            @debug "PTX compiler log:\n" * log
-        end
-        rm(ptx_input)
-
-        # link device libraries, if necessary
-        #
-        # this requires relocatable device code, which prevents certain optimizations and
-        # hurts performance. as such, we only do so when absolutely necessary.
-        # TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`.
-        #       fails with `Ignoring -lto option because no LTO objects found`
-        if needs_cudadevrt
-            nvlink_output = tempname(cleanup=false) * ".cubin"
-            append!(nvlink_opts, [
-                "--verbose", "--extra-warnings",
-                "--arch", arch,
-                "--library-path", dirname(libcudadevrt),
-                "--library", "cudadevrt",
-                "--output-file", nvlink_output,
-                ptxas_output
-            ])
-            proc, log = run_and_collect(`$(CUDA.nvlink()) $nvlink_opts`)
-            log = strip(log)
-            if !success(proc)
-                reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" :
-                                               "nvlink exited with code $(proc.exitcode)"
-                msg = "Failed to link PTX code ($reason)"
-                msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))"
-                if !isempty(log)
-                    msg *= "\n" * log
-                end
-                msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)"
-                error(msg)
-            elseif !isempty(log)
-                @debug "PTX linker info log:\n" * log
-            end
-            rm(ptxas_output)
-
-            image = read(nvlink_output)
-            rm(nvlink_output)
-        else
-            image = read(ptxas_output)
-            rm(ptxas_output)
-        end
-
-        modstr, image, meta.entry
+
+        linkRes = @ccall MLIR.API.mlir_c.LinkInModule(MLIR.IR.mmodule()::MLIR.API.MlirModule, mmod::MLIR.API.MlirModule, entryname::Cstring)::MLIR.API.MlirOperation
+
+        entry = MLIR.IR.Operation(linkRes)
+
+        entry
     end
-    LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, modstr, image, CUDA.LLVM.name(entry))
+    LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, entry)
 end

 # link into an executable kernel
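For contrast with the ptxas/nvlink pipeline deleted above, the device path that survives is just: compile to LLVM, optimize, print to text, reparse inside Reactant's LLVM/MLIR, link. A condensed sketch of the new compile body (assuming, as in the surrounding code, that `modstr` holds the optimized module printed as textual IR):

    entry = GPUCompiler.JuliaContext() do ctx
        mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false)
        entryname = LLVM.name(meta.entry)
        GPUCompiler.optimize_module!(job, mod)

        # Reparse the textual LLVM IR using Reactant's own LLVM, yielding an
        # MLIR module in the LLVM/NVVM dialects registered in API.cpp above.
        mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(
            modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)

        # Splice the device module into the host module currently being traced
        # and keep a handle to the (possibly renamed) kernel entry point.
        linkRes = @ccall MLIR.API.mlir_c.LinkInModule(
            MLIR.IR.mmodule()::MLIR.API.MlirModule,
            mmod::MLIR.API.MlirModule,
            entryname::Cstring)::MLIR.API.MlirOperation
        MLIR.IR.Operation(linkRes)
    end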
@@ -467,7 +319,6 @@ end

 Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
                                                cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt}
-    @show args
     @show call_kwargs

     blockdim = CUDA.CuDim3(blocks)
@@ -478,13 +329,11 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c
     aliases = MLIR.IR.Attribute[]
     rarrays = TracedRArray[]
     for (i, a) in enumerate(args)
-        @show a
         @assert a isa CuTracedArray
         ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
         push!(rarrays, ta)
         arg = ta.mlir_data
         arg = transpose_val(arg)
-        @show arg
         push!(restys, MLIR.IR.type(arg))
         push!(mlir_args, arg)
         push!(aliases,
@@ -500,11 +349,19 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c
     end

     output_operand_aliases=MLIR.IR.Attribute(aliases)
-    call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute("configstr"))
+
+    fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name")
+    # Force public for now while we don't have real users
+    MLIR.IR.rmattr!(func.entry, "sym_visibility")
+
+    call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(fname))
     # call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod))
     for (i, res) in enumerate(rarrays)
         res.mlir_data = transpose_val(MLIR.IR.result(call, i))
     end
+
+    @show blockdim
+    @show threaddim
     #CUDA.cuLaunchKernel(f,
     #                    blockdim.x, blockdim.y, blockdim.z,
     #                    threaddim.x, threaddim.y, threaddim.z,
@@ -523,12 +380,10 @@ function compiler_cache(ctx::MLIR.IR.Context)
 end

 Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
-    @show "recufunction", f, tt
     res = Base.@lock CUDA.cufunction_lock begin
         # compile the function
         cache = compiler_cache(MLIR.IR.context())
         source = CUDA.methodinstance(F, tt)
-
         # cuda = CUDA.active_state()
         device = nothing # cuda.device
         # config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig
543398
config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline)
544399
CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)
545400
end
546-
@show res
547401
res
548402
end
549403
