From 33273fe63ab68d01265c0cbda12c70c15da4ad93 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 7 Oct 2025 11:53:05 +0200 Subject: [PATCH 1/8] Introduce a new interface to intercept the linked module. --- src/interface.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index 157f1977..8d39e963 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -354,6 +354,9 @@ link_libraries!(@nospecialize(job::CompilerJob), mod::LLVM.Module, finish_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry::LLVM.Function) = entry +# finalization of linked modules, after deferred codegen but before optimization +finish_linked_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module) = return + # post-Julia optimization processing of the module optimize_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module) = return From bcb653aeec75af636fbe84a3fdc96dbfff557ec7 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 7 Oct 2025 11:54:17 +0200 Subject: [PATCH 2/8] Deferred kernels: eagerly convert indirect calls to direct ones. Also get rid of unused metadata during argument conversions so that we only have to handle instructions. --- src/driver.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/driver.jl b/src/driver.jl index 8dbff91a..cbc0bf91 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -282,6 +282,19 @@ const __llvm_initialized = Ref(false) erase!(call) end end + + # minimal optimization to convert the inttoptr/call into a direct call + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMFunctionPassManager()) do fpm + add!(fpm, InstCombinePass()) + end + run!(pb, ir, llvm_machine(job.config.target)) + end + ## XXX: LLVM often leaves behind unused constant expressions containing function + ## pointer bitcasts we just optimized away, so prune those manually. + for f in functions(ir) + prune_constexpr_uses!(f) + end end # all deferred compilations should have been resolved From 69f9aa0221ba055f777b883cdfcca904cbdcfa15 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 7 Oct 2025 11:54:53 +0200 Subject: [PATCH 3/8] Invoke the new interface method after kernel metadata is set. --- src/driver.jl | 99 +++++++++++++++++++++++++-------------------------- 1 file changed, 49 insertions(+), 50 deletions(-) diff --git a/src/driver.jl b/src/driver.jl index cbc0bf91..a1e3495e 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -325,10 +325,18 @@ const __llvm_initialized = Ref(false) end @tracepoint "IR post-processing" begin - # mark everything internal except for entrypoints and any exported - # global variables. this makes sure that the optimizer can, e.g., - # rewrite function signatures. + # mark the kernel entry-point functions (optimization may need it) + if job.config.kernel + push!(metadata(ir)["julia.kernel"], MDNode([entry])) + + # IDEA: save all jobs, not only kernels, and save other attributes + # so that we can reconstruct the CompileJob instead of setting it globally + end + if job.config.toplevel + # mark everything internal except for entrypoints and any exported + # global variables. this makes sure that the optimizer can, e.g., + # rewrite function signatures. preserved_gvs = collect(values(jobs)) for gvar in globals(ir) if linkage(gvar) == LLVM.API.LLVMExternalLinkage @@ -344,34 +352,41 @@ const __llvm_initialized = Ref(false) run!(pm, ir) end end - end - - # mark the kernel entry-point functions (optimization may need it) - if job.config.kernel - push!(metadata(ir)["julia.kernel"], MDNode([entry])) - - # IDEA: save all jobs, not only kernels, and save other attributes - # so that we can reconstruct the CompileJob instead of setting it globally - end - if job.config.toplevel && job.config.optimize - @tracepoint "optimization" begin - optimize!(job, ir; job.config.opt_level) + finish_linked_module!(job, ir) + + if job.config.optimize + @tracepoint "optimization" begin + optimize!(job, ir; job.config.opt_level) + + # deferred codegen has some special optimization requirements, + # which also need to happen _after_ regular optimization. + # XXX: make these part of the optimizer pipeline? + if has_deferred_jobs + @dispose pb=NewPMPassBuilder() begin + add!(pb, NewPMFunctionPassManager()) do fpm + add!(fpm, InstCombinePass()) + end + add!(pb, AlwaysInlinerPass()) + add!(pb, NewPMFunctionPassManager()) do fpm + add!(fpm, SROAPass()) + add!(fpm, GVNPass()) + end + add!(pb, MergeFunctionsPass()) + run!(pb, ir, llvm_machine(job.config.target)) + end + end + end + end - # deferred codegen has some special optimization requirements, - # which also need to happen _after_ regular optimization. - # XXX: make these part of the optimizer pipeline? - if has_deferred_jobs + if job.config.cleanup + @tracepoint "clean-up" begin @dispose pb=NewPMPassBuilder() begin - add!(pb, NewPMFunctionPassManager()) do fpm - add!(fpm, InstCombinePass()) - end - add!(pb, AlwaysInlinerPass()) - add!(pb, NewPMFunctionPassManager()) do fpm - add!(fpm, SROAPass()) - add!(fpm, GVNPass()) - end - add!(pb, MergeFunctionsPass()) + add!(pb, RecomputeGlobalsAAPass()) + add!(pb, GlobalOptPass()) + add!(pb, GlobalDCEPass()) + add!(pb, StripDeadPrototypesPass()) + add!(pb, ConstantMergePass()) run!(pb, ir, llvm_machine(job.config.target)) end end @@ -379,29 +394,13 @@ const __llvm_initialized = Ref(false) # optimization may have replaced functions, so look the entry point up again entry = functions(ir)[entry_fn] - end - if job.config.toplevel && job.config.cleanup - @tracepoint "clean-up" begin - @dispose pb=NewPMPassBuilder() begin - add!(pb, RecomputeGlobalsAAPass()) - add!(pb, GlobalOptPass()) - add!(pb, GlobalDCEPass()) - add!(pb, StripDeadPrototypesPass()) - add!(pb, ConstantMergePass()) - run!(pb, ir, llvm_machine(job.config.target)) - end - end - end - - # finish the module - # - # we want to finish the module after optimization, so we cannot do so - # during deferred code generation. instead, process the deferred jobs - # here. - if job.config.toplevel + # finish the module + # + # we want to finish the module after optimization, so we cannot do so + # during deferred code generation. instead, process the deferred jobs + # here. entry = finish_ir!(job, ir, entry) - for (job′, fn′) in jobs job′ == job && continue finish_ir!(job′, ir, functions(ir)[fn′]) @@ -422,7 +421,7 @@ const __llvm_initialized = Ref(false) end if job.config.toplevel && job.config.validate - @tracepoint "Validation" begin + @tracepoint "validation" begin check_ir(job, ir) end end From a4c9a475cedea7fdd724c24d09aeaafac8d06603 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 7 Oct 2025 11:55:08 +0200 Subject: [PATCH 4/8] Ensure kernel metadata is correct on kernel replacement. --- src/irgen.jl | 1 + src/metal.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/src/irgen.jl b/src/irgen.jl index 2d19961a..766ad384 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -1054,6 +1054,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, for (f, new_f) in workmap rewrite_uses!(f, new_f) @assert isempty(uses(f)) + replace_metadata_uses!(f, new_f) erase!(f) end diff --git a/src/metal.jl b/src/metal.jl index 200af854..ab0bc7fe 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -497,6 +497,7 @@ function pass_by_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f # NOTE: if we ever have legitimate uses of the old function, create a shim instead fn = LLVM.name(f) @assert isempty(uses(f)) + replace_metadata_uses!(f, new_f) erase!(f) LLVM.name!(new_f, fn) From 3bb429ac011b7b703088a7df40ddf8ec119bf9f8 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 7 Oct 2025 11:55:45 +0200 Subject: [PATCH 5/8] Run Metal's ABI conversions on the linked module. This makes it possible to use intrinsics in deferred compilations. --- src/metal.jl | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/metal.jl b/src/metal.jl index ab0bc7fe..3badedf5 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -47,13 +47,16 @@ runtime_slug(job::CompilerJob{MetalCompilerTarget}) = "metal-macos$(job.config.t isintrinsic(@nospecialize(job::CompilerJob{MetalCompilerTarget}), fn::String) = return startswith(fn, "air.") -function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module, entry::LLVM.Function) - entry_fn = LLVM.name(entry) - - # update calling conventions - if job.config.kernel - entry = pass_by_reference!(job, mod, entry) - entry = add_input_arguments!(job, mod, entry, kernel_intrinsics) +function finish_linked_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module) + if haskey(metadata(mod), "julia.kernel") + kernels_md = metadata(mod)["julia.kernel"] + for kernel_md in operands(kernels_md) + f = LLVM.Value(operands(kernel_md)[1])::LLVM.Function + + # update calling conventions + f = pass_by_reference!(job, mod, f) + f = add_input_arguments!(job, mod, f, kernel_intrinsics) + end end # emit the AIR and Metal version numbers as constants in the module. this makes it @@ -83,7 +86,7 @@ function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mo run!(pb, mod) end - return functions(mod)[entry_fn] + return end function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module) From e24219021bb18e89a2685573a55397368ebe9a27 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 7 Oct 2025 12:49:05 +0200 Subject: [PATCH 6/8] Add helpers for iterating kernels in a module. --- src/driver.jl | 5 +---- src/irgen.jl | 8 +------- src/metal.jl | 13 ++++--------- src/utils.jl | 27 +++++++++++++++++++++++++++ 4 files changed, 33 insertions(+), 20 deletions(-) diff --git a/src/driver.jl b/src/driver.jl index a1e3495e..ac30c7ab 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -327,10 +327,7 @@ const __llvm_initialized = Ref(false) @tracepoint "IR post-processing" begin # mark the kernel entry-point functions (optimization may need it) if job.config.kernel - push!(metadata(ir)["julia.kernel"], MDNode([entry])) - - # IDEA: save all jobs, not only kernels, and save other attributes - # so that we can reconstruct the CompileJob instead of setting it globally + mark_kernel!(entry) end if job.config.toplevel diff --git a/src/irgen.jl b/src/irgen.jl index 766ad384..7a2071c7 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -526,18 +526,12 @@ function add_kernel_state!(mod::LLVM.Module) state_intr = kernel_state_intr(mod, T_state) state_intr_ft = LLVM.FunctionType(T_state) - kernels = [] - kernels_md = metadata(mod)["julia.kernel"] - for kernel_md in operands(kernels_md) - push!(kernels, Value(operands(kernel_md)[1])) - end - # determine which functions need a kernel state argument # # previously, we add the argument to every function and relied on unused arg elim to # clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting # function pointers. such IR is hard to rewrite, so instead be more conservative. - worklist = Set{LLVM.Function}([state_intr, kernels...]) + worklist = Set{LLVM.Function}([state_intr, kernels(mod)...]) worklist_length = 0 while worklist_length != length(worklist) # iteratively discover functions that use the intrinsic or any function calling it diff --git a/src/metal.jl b/src/metal.jl index 3badedf5..d3a83d61 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -48,15 +48,10 @@ isintrinsic(@nospecialize(job::CompilerJob{MetalCompilerTarget}), fn::String) = return startswith(fn, "air.") function finish_linked_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module) - if haskey(metadata(mod), "julia.kernel") - kernels_md = metadata(mod)["julia.kernel"] - for kernel_md in operands(kernels_md) - f = LLVM.Value(operands(kernel_md)[1])::LLVM.Function - - # update calling conventions - f = pass_by_reference!(job, mod, f) - f = add_input_arguments!(job, mod, f, kernel_intrinsics) - end + for f in kernels(mod) + # update calling conventions + f = pass_by_reference!(job, mod, f) + f = add_input_arguments!(job, mod, f, kernel_intrinsics) end # emit the AIR and Metal version numbers as constants in the module. this makes it diff --git a/src/utils.jl b/src/utils.jl index bad2a307..095f22dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -155,3 +155,30 @@ function prune_constexpr_uses!(root::LLVM.Value) end end end + + +## kernel metadata handling + +# kernels are encoded in the IR using the julia.kernel metadata. + +# IDEA: don't only mark kernels, but all jobs, and save all attributes of the CompileJob +# so that we can reconstruct the CompileJob instead of setting it globally + +# mark a function as kernel +function mark_kernel!(f::LLVM.Function) + mod = LLVM.parent(f) + push!(metadata(mod)["julia.kernel"], MDNode([f])) + return f +end + +# iterate over all kernels in the module +function kernels(mod::LLVM.Module) + vals = LLVM.Function[] + if haskey(metadata(mod), "julia.kernel") + kernels_md = metadata(mod)["julia.kernel"] + for kernel_md in operands(kernels_md) + push!(vals, LLVM.Value(operands(kernel_md)[1])) + end + end + return vals +end From 0a4d0b759ad186557e87514a8851dc9352a0fab5 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 7 Oct 2025 13:29:59 +0200 Subject: [PATCH 7/8] Make add_input_arguments! resilient to constant expressions. --- src/driver.jl | 5 ----- src/irgen.jl | 24 ++++++++++++++++++------ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/driver.jl b/src/driver.jl index ac30c7ab..950ea272 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -290,11 +290,6 @@ const __llvm_initialized = Ref(false) end run!(pb, ir, llvm_machine(job.config.target)) end - ## XXX: LLVM often leaves behind unused constant expressions containing function - ## pointer bitcasts we just optimized away, so prune those manually. - for f in functions(ir) - prune_constexpr_uses!(f) - end end # all deferred compilations should have been resolved diff --git a/src/irgen.jl b/src/irgen.jl index 7a2071c7..2d363150 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -935,12 +935,24 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, while worklist_length != length(worklist) # iteratively discover functions that use an intrinsic or any function calling it worklist_length = length(worklist) - additions = LLVM.Function[] - for f in worklist, use in uses(f) - inst = user(use)::Instruction - bb = LLVM.parent(inst) - new_f = LLVM.parent(bb) - in(new_f, worklist) || push!(additions, new_f) + additions = Set{LLVM.Function}() + function scan_uses(val) + for use in uses(val) + candidate = user(use) + if isa(candidate, Instruction) + bb = LLVM.parent(candidate) + new_f = LLVM.parent(bb) + in(new_f, worklist) || push!(additions, new_f) + elseif isa(candidate, ConstantExpr) + @safe_info candidate + scan_uses(candidate) + else + error("Don't know how to check uses of $candidate. Please file an issue.") + end + end + end + for f in worklist + scan_uses(f) end for f in additions push!(worklist, f) From 1722ae3fb370a9adf405d77d1ae24b4a324ca695 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 7 Oct 2025 13:11:09 +0200 Subject: [PATCH 8/8] Bump version. --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6aee0967..8761bb41 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GPUCompiler" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "1.7.0" +version = "1.7.1" authors = ["Tim Besard "] [deps]