From 70115c76b074f8a9c44ef13ff444c3b287609d46 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 18 Aug 2025 15:02:20 +0200 Subject: [PATCH 1/2] Fix deferred_codegen registration Co-authored-by: Gabriel Baraldi --- src/GPUCompiler.jl | 22 ++++++++++++++++++++++ src/driver.jl | 3 ++- test/utils.jl | 2 ++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/GPUCompiler.jl b/src/GPUCompiler.jl index 5d337e5e..64884961 100644 --- a/src/GPUCompiler.jl +++ b/src/GPUCompiler.jl @@ -66,6 +66,28 @@ function __init__() global compile_cache = dir Tracy.@register_tracepoints() + + # Register deferred_codegen as a global function so that it can be called with `ccall("extern deferred_codegen"` + @dispose jljit=JuliaOJIT() begin + jd = JITDylib(jljit) + + address = LLVM.API.LLVMOrcJITTargetAddress( + reinterpret(UInt, @cfunction(deferred_codegen, Ptr{Cvoid}, (Ptr{Cvoid},)))) + flags = LLVM.API.LLVMJITSymbolFlags( + LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) + name = mangle(jljit, "deferred_codegen") + symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags) + map = if LLVM.version() >= v"15" + LLVM.API.LLVMOrcCSymbolMapPair(name, symbol) + else + LLVM.API.LLVMJITCSymbolMapPair(name, symbol) + end + + mu = LLVM.absolute_symbols(Ref(map)) + LLVM.define(jd, mu) + addr = lookup(jljit, "deferred_codegen") + @assert addr != C_NULL "Failed to register deferred_codegen" + end end end # module diff --git a/src/driver.jl b/src/driver.jl index f6106118..54f6b65c 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -129,7 +129,8 @@ const deferred_codegen_jobs = Dict{Int, Any}() # We make this function explicitly callable so that we can drive OrcJIT's # lazy compilation from, while also enabling recursive compilation. 
-Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid}) +# Julia 1.11 and co broke @ccallable so we have to do this manually in __init__ +function deferred_codegen(ptr::Ptr{Cvoid})::Ptr{Cvoid} ptr end diff --git a/test/utils.jl b/test/utils.jl index 595a748e..3b742795 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -190,4 +190,6 @@ end @testset "Mock Enzyme" begin Enzyme.deferred_codegen_id(typeof(identity), Tuple{Vector{Float64}}) + # Check that we can call this function from the CPU, to support deferred codegen for Enzyme. + @test ccall("extern deferred_codegen", llvmcall, UInt, (UInt,), 3) == 3 end From 5e2a80dc2a2a1cfef28558c1cb27816df4aa9b55 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 24 Sep 2025 19:07:54 +0200 Subject: [PATCH 2/2] move code next to driver --- src/GPUCompiler.jl | 23 +---------------------- src/driver.jl | 29 ++++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/src/GPUCompiler.jl b/src/GPUCompiler.jl index 64884961..ec33e9a7 100644 --- a/src/GPUCompiler.jl +++ b/src/GPUCompiler.jl @@ -66,28 +66,7 @@ function __init__() global compile_cache = dir Tracy.@register_tracepoints() - - # Register deferred_codegen as a global function so that it can be called with `ccall("extern deferred_codegen"` - @dispose jljit=JuliaOJIT() begin - jd = JITDylib(jljit) - - address = LLVM.API.LLVMOrcJITTargetAddress( - reinterpret(UInt, @cfunction(deferred_codegen, Ptr{Cvoid}, (Ptr{Cvoid},)))) - flags = LLVM.API.LLVMJITSymbolFlags( - LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) - name = mangle(jljit, "deferred_codegen") - symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags) - map = if LLVM.version() >= v"15" - LLVM.API.LLVMOrcCSymbolMapPair(name, symbol) - else - LLVM.API.LLVMJITCSymbolMapPair(name, symbol) - end - - mu = LLVM.absolute_symbols(Ref(map)) - LLVM.define(jd, mu) - addr = lookup(jljit, "deferred_codegen") - @assert addr != C_NULL "Failed to register deferred_codegen" - end + 
register_deferred_codegen() end end # module diff --git a/src/driver.jl b/src/driver.jl index 54f6b65c..8dbff91a 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -129,7 +129,7 @@ const deferred_codegen_jobs = Dict{Int, Any}() # We make this function explicitly callable so that we can drive OrcJIT's # lazy compilation from, while also enabling recursive compilation. -# Julia 1.11 and co broke @ccallable so we have to do this manually in __init__ +# see `register_deferred_codegen` function deferred_codegen(ptr::Ptr{Cvoid})::Ptr{Cvoid} ptr end @@ -150,6 +150,33 @@ end end end +# Register deferred_codegen as a global function so that it can be called with `ccall("extern deferred_codegen", ...)` +# Called from __init__ +# On 1.11+ this is needed due to a Julia bug that drops the pointer when code-coverage is enabled. +function register_deferred_codegen() + @dispose jljit=JuliaOJIT() begin + jd = JITDylib(jljit) + + address = LLVM.API.LLVMOrcJITTargetAddress( + reinterpret(UInt, @cfunction(deferred_codegen, Ptr{Cvoid}, (Ptr{Cvoid},)))) + flags = LLVM.API.LLVMJITSymbolFlags( + LLVM.API.LLVMJITSymbolGenericFlagsExported, 0) + name = mangle(jljit, "deferred_codegen") + symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags) + map = if LLVM.version() >= v"15" + LLVM.API.LLVMOrcCSymbolMapPair(name, symbol) + else + LLVM.API.LLVMJITCSymbolMapPair(name, symbol) + end + + mu = LLVM.absolute_symbols(Ref(map)) + LLVM.define(jd, mu) + addr = lookup(jljit, "deferred_codegen") + @assert addr != C_NULL "Failed to register deferred_codegen" + end + return nothing +end + const __llvm_initialized = Ref(false) @locked function emit_llvm(@nospecialize(job::CompilerJob); kwargs...)