|
43 | 43 |
|
44 | 44 | function var"gpuc.deferred" end
|
45 | 45 |
|
46 |
| -# old, deprecated mechanism slated for removal once Enzyme is updated to the new intrinsic |
47 |
| -begin |
48 |
| - # primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism. |
49 |
| - # this could both be generalized (e.g. supporting actual function calls, instead of |
50 |
| - # returning a function pointer), and be integrated with the nonrecursive codegen. |
51 |
| - const deferred_codegen_jobs = Dict{Int, Any}() |
52 |
| - |
53 |
| - # We make this function explicitly callable so that we can drive OrcJIT's |
54 |
| - # lazy compilation from it, while also enabling recursive compilation. |
55 |
| - Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid}) |
56 |
| - ptr |
57 |
| - end |
58 |
| - |
59 |
| - @generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt} |
60 |
| - id = length(deferred_codegen_jobs) + 1 |
61 |
| - deferred_codegen_jobs[id] = (; ft, tt) |
62 |
| - # don't bother looking up the method instance, as we'll do so again during codegen |
63 |
| - # using the world age of the parent. |
64 |
| - # |
65 |
| - # this also works around an issue on <1.10, where we don't know the world age of |
66 |
| - # generated functions so use the current world counter, which may be too new |
67 |
| - # for the world we're compiling for. |
68 |
| - |
69 |
| - quote |
70 |
| - # TODO: add an edge to this method instance to support method redefinitions |
71 |
| - ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id) |
72 |
| - end |
73 |
| - end |
74 |
| -end |
75 |
| - |
76 |
| - |
77 | 46 | ## compiler entrypoint
|
78 | 47 |
|
79 | 48 | export compile
|
@@ -198,7 +167,6 @@ const __llvm_initialized = Ref(false)
|
198 | 167 |
|
199 | 168 | # gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
|
200 | 169 | # target method instance from the LLVM IR
|
201 |
| - # TODO: drive deferred compilation from the Julia IR instead |
202 | 170 | function find_base_object(val)
|
203 | 171 | while true
|
204 | 172 | if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
|
@@ -263,80 +231,6 @@ const __llvm_initialized = Ref(false)
|
263 | 231 | @compiler_assert isempty(uses(dyn_marker)) job
|
264 | 232 | unsafe_delete!(ir, dyn_marker)
|
265 | 233 | end
|
266 |
| - ## old, deprecated implementation |
267 |
| - jobs = Dict{CompilerJob, String}(job => entry_fn) |
268 |
| - if toplevel && !only_entry && haskey(functions(ir), "deferred_codegen") |
269 |
| - run_optimization_for_deferred = true |
270 |
| - dyn_marker = functions(ir)["deferred_codegen"] |
271 |
| - |
272 |
| - # iterative compilation (non-recursive) |
273 |
| - changed = true |
274 |
| - while changed |
275 |
| - changed = false |
276 |
| - |
277 |
| - # find deferred compiler |
278 |
| - worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}() |
279 |
| - for use in uses(dyn_marker) |
280 |
| - # decode the call |
281 |
| - call = user(use)::LLVM.CallInst |
282 |
| - id = convert(Int, first(operands(call))) |
283 |
| - |
284 |
| - global deferred_codegen_jobs |
285 |
| - dyn_val = deferred_codegen_jobs[id] |
286 |
| - |
287 |
| - # get a job in the appropriate world |
288 |
| - dyn_job = if dyn_val isa CompilerJob |
289 |
| - # trust that the user knows what they're doing |
290 |
| - dyn_val |
291 |
| - else |
292 |
| - ft, tt = dyn_val |
293 |
| - dyn_src = methodinstance(ft, tt, tls_world_age()) |
294 |
| - CompilerJob(dyn_src, job.config) |
295 |
| - end |
296 |
| - |
297 |
| - push!(get!(worklist, dyn_job, LLVM.CallInst[]), call) |
298 |
| - end |
299 |
| - |
300 |
| - # compile and link |
301 |
| - for dyn_job in keys(worklist) |
302 |
| - # cached compilation |
303 |
| - dyn_entry_fn = get!(jobs, dyn_job) do |
304 |
| - dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false, |
305 |
| - parent_job=job) |
306 |
| - dyn_entry_fn = LLVM.name(dyn_meta.entry) |
307 |
| - merge!(compiled, dyn_meta.compiled) |
308 |
| - @assert context(dyn_ir) == context(ir) |
309 |
| - link!(ir, dyn_ir) |
310 |
| - changed = true |
311 |
| - dyn_entry_fn |
312 |
| - end |
313 |
| - dyn_entry = functions(ir)[dyn_entry_fn] |
314 |
| - |
315 |
| - # insert a pointer to the function everywhere the entry is used |
316 |
| - T_ptr = convert(LLVMType, Ptr{Cvoid}) |
317 |
| - for call in worklist[dyn_job] |
318 |
| - @dispose builder=IRBuilder() begin |
319 |
| - position!(builder, call) |
320 |
| - fptr = if LLVM.version() >= v"17" |
321 |
| - T_ptr = LLVM.PointerType() |
322 |
| - bitcast!(builder, dyn_entry, T_ptr) |
323 |
| - elseif VERSION >= v"1.12.0-DEV.225" |
324 |
| - T_ptr = LLVM.PointerType(LLVM.Int8Type()) |
325 |
| - bitcast!(builder, dyn_entry, T_ptr) |
326 |
| - else |
327 |
| - ptrtoint!(builder, dyn_entry, T_ptr) |
328 |
| - end |
329 |
| - replace_uses!(call, fptr) |
330 |
| - end |
331 |
| - unsafe_delete!(LLVM.parent(call), call) |
332 |
| - end |
333 |
| - end |
334 |
| - end |
335 |
| - |
336 |
| - # all deferred compilations should have been resolved |
337 |
| - @compiler_assert isempty(uses(dyn_marker)) job |
338 |
| - unsafe_delete!(ir, dyn_marker) |
339 |
| - end |
340 | 234 |
|
341 | 235 | if libraries
|
342 | 236 | # load the runtime outside of a timing block (because it recurses into the compiler)
|
@@ -433,8 +327,8 @@ const __llvm_initialized = Ref(false)
|
433 | 327 | # finish the module
|
434 | 328 | #
|
435 | 329 | # we want to finish the module after optimization, so we cannot do so
|
436 |
| - # during deferred code generation. instead, process the deferred jobs |
437 |
| - # here. |
| 330 | + # during deferred code generation. Instead, process the merged module |
| 331 | + # from all the jobs here. |
438 | 332 | if toplevel
|
439 | 333 | entry = finish_ir!(job, ir, entry)
|
440 | 334 |
|
|
0 commit comments