Skip to content

Commit 9ee6efb

Browse files
committed
move code next to driver
1 parent 477be80 commit 9ee6efb

File tree

2 files changed

+29
-23
lines changed

2 files changed

+29
-23
lines changed

src/GPUCompiler.jl

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -66,28 +66,7 @@ function __init__()
6666
global compile_cache = dir
6767

6868
Tracy.@register_tracepoints()
69-
70-
# Register deferred_codegen as a global function so that it can be called with `ccall("extern deferred_codegen"`
71-
@dispose jljit=JuliaOJIT() begin
72-
jd = JITDylib(jljit)
73-
74-
address = LLVM.API.LLVMOrcJITTargetAddress(
75-
reinterpret(UInt, @cfunction(deferred_codegen, Ptr{Cvoid}, (Ptr{Cvoid},))))
76-
flags = LLVM.API.LLVMJITSymbolFlags(
77-
LLVM.API.LLVMJITSymbolGenericFlagsExported, 0)
78-
name = mangle(jljit, "deferred_codegen")
79-
symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags)
80-
map = if LLVM.version() >= v"15"
81-
LLVM.API.LLVMOrcCSymbolMapPair(name, symbol)
82-
else
83-
LLVM.API.LLVMJITCSymbolMapPair(name, symbol)
84-
end
85-
86-
mu = LLVM.absolute_symbols(Ref(map))
87-
LLVM.define(jd, mu)
88-
addr = lookup(jljit, "deferred_codegen")
89-
@assert addr != C_NULL "Failed to register deferred_codegen"
90-
end
69+
register_deferred_codegen()
9170
end
9271

9372
end # module

src/driver.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ const deferred_codegen_jobs = Dict{Int, Any}()
129129

130130
# We make this function explicitly callable so that we can drive OrcJIT's
131131
# lazy compilation from, while also enabling recursive compilation.
132-
# Julia 1.11 and co broke @ccallable so we have to do this manually in __init__
132+
# see `register_deferred_codegen`
133133
function deferred_codegen(ptr::Ptr{Cvoid})::Ptr{Cvoid}
134134
ptr
135135
end
@@ -150,6 +150,33 @@ end
150150
end
151151
end
152152

153+
# Register deferred_codegen as a global function so that it can be called with `ccall("extern deferred_codegen"`
154+
# Called from __init__
155+
# On 1.11+ this is needed due to a Julia bug that drops the pointer when code-coverage is enabled.
156+
function register_deferred_codegen()
157+
@dispose jljit=JuliaOJIT() begin
158+
jd = JITDylib(jljit)
159+
160+
address = LLVM.API.LLVMOrcJITTargetAddress(
161+
reinterpret(UInt, @cfunction(deferred_codegen, Ptr{Cvoid}, (Ptr{Cvoid},))))
162+
flags = LLVM.API.LLVMJITSymbolFlags(
163+
LLVM.API.LLVMJITSymbolGenericFlagsExported, 0)
164+
name = mangle(jljit, "deferred_codegen")
165+
symbol = LLVM.API.LLVMJITEvaluatedSymbol(address, flags)
166+
map = if LLVM.version() >= v"15"
167+
LLVM.API.LLVMOrcCSymbolMapPair(name, symbol)
168+
else
169+
LLVM.API.LLVMJITCSymbolMapPair(name, symbol)
170+
end
171+
172+
mu = LLVM.absolute_symbols(Ref(map))
173+
LLVM.define(jd, mu)
174+
addr = lookup(jljit, "deferred_codegen")
175+
@assert addr != C_NULL "Failed to register deferred_codegen"
176+
end
177+
return nothing
178+
end
179+
153180
const __llvm_initialized = Ref(false)
154181

155182
@locked function emit_llvm(@nospecialize(job::CompilerJob); kwargs...)

0 commit comments

Comments
 (0)