Skip to content

Commit 477be80

Browse files
vchuravygbaraldi
andcommitted
Fix deferred_codegen registration
Co-authored-by: Gabriel Baraldi <[email protected]>
1 parent 32b4fc8 commit 477be80

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

src/GPUCompiler.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,28 @@ function __init__()
6666
global compile_cache = dir
6767

6868
Tracy.@register_tracepoints()
69+
70+
# Register deferred_codegen as a global function so that it can be called with `ccall("extern deferred_codegen"`
71+
@dispose jljit=JuliaOJIT() begin
72+
jd = JITDylib(jljit)
73+
74+
address = LLVM.API.LLVMOrcJITTargetAddress(
75+
reinterpret(UInt, @cfunction(deferred_codegen, Ptr{Cvoid}, (Ptr{Cvoid},))))
76+
flags = LLVM.API.LLVMJITSymbolFlags(
77+
LLVM.API.LLVMJITSymbolGenericFlagsExported, 0)
78+
name = mangle(jljit, "deferred_codegen")
79+
symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags)
80+
map = if LLVM.version() >= v"15"
81+
LLVM.API.LLVMOrcCSymbolMapPair(name, symbol)
82+
else
83+
LLVM.API.LLVMJITCSymbolMapPair(name, symbol)
84+
end
85+
86+
mu = LLVM.absolute_symbols(Ref(map))
87+
LLVM.define(jd, mu)
88+
addr = lookup(jljit, "deferred_codegen")
89+
@assert addr != C_NULL "Failed to register deferred_codegen"
90+
end
6991
end
7092

7193
end # module

src/driver.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ const deferred_codegen_jobs = Dict{Int, Any}()
129129

130130
# We make this function explicitly callable so that we can drive OrcJIT's
131131
# lazy compilation from, while also enabling recursive compilation.
132-
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
132+
# Julia 1.11 and co broke @ccallable so we have to do this manually in __init__
133+
function deferred_codegen(ptr::Ptr{Cvoid})::Ptr{Cvoid}
133134
ptr
134135
end
135136

test/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,6 @@ end
190190

191191
@testset "Mock Enzyme" begin
192192
Enzyme.deferred_codegen_id(typeof(identity), Tuple{Vector{Float64}})
193+
# Check that we can call this function from the CPU, to support deferred codegen for Enzyme.
194+
@test ccall("extern deferred_codegen", llvmcall, UInt, (UInt,), 3) == 3
193195
end

0 commit comments

Comments
 (0)