Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/GPUCompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ function __init__()
global compile_cache = dir

Tracy.@register_tracepoints()
register_deferred_codegen()
end

end # module
30 changes: 29 additions & 1 deletion src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ const deferred_codegen_jobs = Dict{Int, Any}()

# We make this function explicitly callable so that we can drive OrcJIT's
# lazy compilation from, while also enabling recursive compilation.
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
# see `register_deferred_codegen`
function deferred_codegen(ptr::Ptr{Cvoid})::Ptr{Cvoid}
ptr
end

Expand All @@ -149,6 +150,33 @@ end
end
end

# Register deferred_codegen as a global function so that it can be called with `ccall("extern deferred_codegen"`
# Called from __init__
# On 1.11+ this is needed due to a Julia bug that drops the pointer when code-coverage is enabled.
function register_deferred_codegen()
@dispose jljit=JuliaOJIT() begin
jd = JITDylib(jljit)

address = LLVM.API.LLVMOrcJITTargetAddress(
reinterpret(UInt, @cfunction(deferred_codegen, Ptr{Cvoid}, (Ptr{Cvoid},))))
flags = LLVM.API.LLVMJITSymbolFlags(
LLVM.API.LLVMJITSymbolGenericFlagsExported, 0)
name = mangle(jljit, "deferred_codegen")
symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags)
map = if LLVM.version() >= v"15"
LLVM.API.LLVMOrcCSymbolMapPair(name, symbol)
else
LLVM.API.LLVMJITCSymbolMapPair(name, symbol)
end

mu = LLVM.absolute_symbols(Ref(map))
LLVM.define(jd, mu)
addr = lookup(jljit, "deferred_codegen")
@assert addr != C_NULL "Failed to register deferred_codegen"
end
Copy link
Member

@gbaraldi gbaraldi Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we ccall it here just to check if it's working? I guess the address is enough but not sure. I guess the test is enough

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think the ccall would not add more information, and if failed lead to a more confusing error.

return nothing
end

const __llvm_initialized = Ref(false)

@locked function emit_llvm(@nospecialize(job::CompilerJob); kwargs...)
Expand Down
2 changes: 2 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,6 @@ end

@testset "Mock Enzyme" begin
Enzyme.deferred_codegen_id(typeof(identity), Tuple{Vector{Float64}})
# Check that we can call this function from the CPU, to support deferred codegen for Enzyme.
@test ccall("extern deferred_codegen", llvmcall, UInt, (UInt,), 3) == 3
end
Loading