Skip to content

Commit c7e3274

Browse files
committed
Add new deferred compilation mechanism.
1 parent 316668b commit c7e3274

File tree

4 files changed

+298
-36
lines changed

4 files changed

+298
-36
lines changed

src/driver.jl

Lines changed: 111 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,41 @@ function JuliaContext(f; kwargs...)
3939
end
4040

4141

42+
## deferred compilation
43+
44+
function var"gpuc.deferred" end
45+
46+
# old, deprecated mechanism slated for removal once Enzyme is updated to the new intrinsic
47+
begin
48+
# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
49+
# this could both be generalized (e.g. supporting actual function calls, instead of
50+
# returning a function pointer), and be integrated with the nonrecursive codegen.
51+
const deferred_codegen_jobs = Dict{Int, Any}()
52+
53+
# We make this function explicitly callable so that we can drive OrcJIT's
54+
# lazy compilation from, while also enabling recursive compilation.
55+
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
56+
ptr
57+
end
58+
59+
@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
60+
id = length(deferred_codegen_jobs) + 1
61+
deferred_codegen_jobs[id] = (; ft, tt)
62+
# don't bother looking up the method instance, as we'll do so again during codegen
63+
# using the world age of the parent.
64+
#
65+
# this also works around an issue on <1.10, where we don't know the world age of
66+
# generated functions so use the current world counter, which may be too new
67+
# for the world we're compiling for.
68+
69+
quote
70+
# TODO: add an edge to this method instance to support method redefinitions
71+
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
72+
end
73+
end
74+
end
75+
76+
4277
## compiler entrypoint
4378

4479
export compile
@@ -127,33 +162,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
127162
error("Unknown compilation output $output")
128163
end
129164

130-
# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
131-
# this could both be generalized (e.g. supporting actual function calls, instead of
132-
# returning a function pointer), and be integrated with the nonrecursive codegen.
133-
const deferred_codegen_jobs = Dict{Int, Any}()
134-
135-
# We make this function explicitly callable so that we can drive OrcJIT's
136-
# lazy compilation from, while also enabling recursive compilation.
137-
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
138-
ptr
139-
end
140-
141-
@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
142-
id = length(deferred_codegen_jobs) + 1
143-
deferred_codegen_jobs[id] = (; ft, tt)
144-
# don't bother looking up the method instance, as we'll do so again during codegen
145-
# using the world age of the parent.
146-
#
147-
# this also works around an issue on <1.10, where we don't know the world age of
148-
# generated functions so use the current world counter, which may be too new
149-
# for the world we're compiling for.
150-
151-
quote
152-
# TODO: add an edge to this method instance to support method redefinitions
153-
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
154-
end
155-
end
156-
157165
const __llvm_initialized = Ref(false)
158166

159167
@locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
@@ -183,9 +191,82 @@ const __llvm_initialized = Ref(false)
183191
entry = finish_module!(job, ir, entry)
184192

185193
# deferred code generation
186-
has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
194+
run_optimization_for_deferred = false
195+
if haskey(functions(ir), "gpuc.lookup")
196+
run_optimization_for_deferred = true
197+
dyn_marker = functions(ir)["gpuc.lookup"]
198+
199+
# gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
200+
# target method instance from the LLVM IR
201+
# TODO: drive deferred compilation from the Julia IR instead
202+
function find_base_object(val)
203+
while true
204+
if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
205+
opcode(val) == LLVM.API.LLVMBitCast ||
206+
opcode(val) == LLVM.API.LLVMAddrSpaceCast)
207+
val = first(operands(val))
208+
elseif val isa LLVM.IntToPtrInst ||
209+
val isa LLVM.BitCastInst ||
210+
val isa LLVM.AddrSpaceCastInst
211+
val = first(operands(val))
212+
elseif val isa LLVM.LoadInst
213+
# In 1.11+ we no longer embed integer constants directly.
214+
gv = first(operands(val))
215+
if gv isa LLVM.GlobalValue
216+
val = LLVM.initializer(gv)
217+
continue
218+
end
219+
break
220+
else
221+
break
222+
end
223+
end
224+
return val
225+
end
226+
227+
worklist = Dict{Any, Vector{LLVM.CallInst}}()
228+
for use in uses(dyn_marker)
229+
# decode the call
230+
call = user(use)::LLVM.CallInst
231+
dyn_mi_inst = find_base_object(operands(call)[1])
232+
@compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
233+
dyn_mi = Base.unsafe_pointer_to_objref(
234+
convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
235+
push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
236+
end
237+
238+
for dyn_mi in keys(worklist)
239+
dyn_fn_name = compiled[dyn_mi].specfunc
240+
dyn_fn = functions(ir)[dyn_fn_name]
241+
242+
# insert a pointer to the function everywhere the entry is used
243+
T_ptr = convert(LLVMType, Ptr{Cvoid})
244+
for call in worklist[dyn_mi]
245+
@dispose builder=IRBuilder() begin
246+
position!(builder, call)
247+
fptr = if LLVM.version() >= v"17"
248+
T_ptr = LLVM.PointerType()
249+
bitcast!(builder, dyn_fn, T_ptr)
250+
elseif VERSION >= v"1.12.0-DEV.225"
251+
T_ptr = LLVM.PointerType(LLVM.Int8Type())
252+
bitcast!(builder, dyn_fn, T_ptr)
253+
else
254+
ptrtoint!(builder, dyn_fn, T_ptr)
255+
end
256+
replace_uses!(call, fptr)
257+
end
258+
unsafe_delete!(LLVM.parent(call), call)
259+
end
260+
end
261+
262+
# all deferred compilations should have been resolved
263+
@compiler_assert isempty(uses(dyn_marker)) job
264+
unsafe_delete!(ir, dyn_marker)
265+
end
266+
## old, deprecated implementation
187267
jobs = Dict{CompilerJob, String}(job => entry_fn)
188-
if has_deferred_jobs
268+
if toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
269+
run_optimization_for_deferred = true
189270
dyn_marker = functions(ir)["deferred_codegen"]
190271

191272
# iterative compilation (non-recursive)
@@ -194,7 +275,6 @@ const __llvm_initialized = Ref(false)
194275
changed = false
195276

196277
# find deferred compiler
197-
# TODO: recover this information earlier, from the Julia IR
198278
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
199279
for use in uses(dyn_marker)
200280
# decode the call
@@ -317,7 +397,7 @@ const __llvm_initialized = Ref(false)
317397
# deferred codegen has some special optimization requirements,
318398
# which also need to happen _after_ regular optimization.
319399
# XXX: make these part of the optimizer pipeline?
320-
if has_deferred_jobs
400+
if run_optimization_for_deferred
321401
@dispose pb=NewPMPassBuilder() begin
322402
add!(pb, NewPMFunctionPassManager()) do fpm
323403
add!(fpm, InstCombinePass())

src/irgen.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,17 @@ function irgen(@nospecialize(job::CompilerJob))
8080
compiled[job.source] =
8181
(; compiled[job.source].ci, func, specfunc)
8282

83+
# Earlier we sanitize global names, this invalidates the
84+
# func, specfunc names safed in compiled. Update the names now,
85+
# such that when when use the compiled mappings to lookup the
86+
# llvm function for a methodinstance (deferred codegen) we have
87+
# valid targets.
88+
for mi in keys(compiled)
89+
mi == job.source && continue
90+
ci, func, specfunc = compiled[mi]
91+
compiled[mi] = (; ci, func=safe_name(func), specfunc=safe_name(specfunc))
92+
end
93+
8394
# minimal required optimization
8495
@timeit_debug to "rewrite" begin
8596
if job.config.kernel && needs_byval(job)

0 commit comments

Comments
 (0)