Skip to content

Commit b54b5e4

Browse files
committed
remove old deferred implementation
1 parent 2fa7871 commit b54b5e4

File tree

2 files changed

+30
-135
lines changed

2 files changed

+30
-135
lines changed

examples/jit.jl

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -116,31 +116,31 @@ function get_trampoline(job)
116116
return addr
117117
end
118118

119-
import GPUCompiler: deferred_codegen_jobs
120-
@generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
121-
# manual version of native_job because we have a function type
122-
source = methodinstance(F, Base.to_tuple_type(tt), world)
123-
target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
124-
# XXX: do we actually require the Julia runtime?
125-
# with jlruntime=false, we reach an unreachable.
126-
params = TestCompilerParams()
127-
config = CompilerConfig(target, params; kernel=false)
128-
job = CompilerJob(source, config, world)
129-
# XXX: invoking GPUCompiler from a generated function is not allowed!
130-
# for things to work, we need to forward the correct world, at least.
131-
132-
addr = get_trampoline(job)
133-
trampoline = pointer(addr)
134-
id = Base.reinterpret(Int, trampoline)
135-
136-
deferred_codegen_jobs[id] = job
137-
138-
quote
139-
ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
140-
assume(ptr != C_NULL)
141-
return ptr
142-
end
143-
end
119+
# import GPUCompiler: deferred_codegen_jobs
120+
# @generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
121+
# # manual version of native_job because we have a function type
122+
# source = methodinstance(F, Base.to_tuple_type(tt), world)
123+
# target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
124+
# # XXX: do we actually require the Julia runtime?
125+
# # with jlruntime=false, we reach an unreachable.
126+
# params = TestCompilerParams()
127+
# config = CompilerConfig(target, params; kernel=false)
128+
# job = CompilerJob(source, config, world)
129+
# # XXX: invoking GPUCompiler from a generated function is not allowed!
130+
# # for things to work, we need to forward the correct world, at least.
131+
132+
# addr = get_trampoline(job)
133+
# trampoline = pointer(addr)
134+
# id = Base.reinterpret(Int, trampoline)
135+
136+
# deferred_codegen_jobs[id] = job
137+
138+
# quote
139+
# ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
140+
# assume(ptr != C_NULL)
141+
# return ptr
142+
# end
143+
# end
144144

145145
@generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N}
146146
argtt = tt.parameters[1]
@@ -224,8 +224,9 @@ end
224224
@inline function call_delayed(f::F, args...) where F
225225
tt = Tuple{map(Core.Typeof, args)...}
226226
rt = Core.Compiler.return_type(f, tt)
227-
world = GPUCompiler.tls_world_age()
228-
ptr = deferred_codegen(f, Val(tt), Val(world))
227+
# FIXME: Horrible idea, have `var"gpuc.deferred"` actually do the work
228+
# But that will only be needed here, and in Enzyme...
229+
ptr = GPUCompiler.var"gpuc.deferred"(f, args...)
229230
abi_call(ptr, rt, tt, f, args...)
230231
end
231232

src/driver.jl

Lines changed: 2 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -43,37 +43,6 @@ end
4343

4444
function var"gpuc.deferred" end
4545

46-
# old, deprecated mechanism slated for removal once Enzyme is updated to the new intrinsic
47-
begin
48-
# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
49-
# this could both be generalized (e.g. supporting actual function calls, instead of
50-
# returning a function pointer), and be integrated with the nonrecursive codegen.
51-
const deferred_codegen_jobs = Dict{Int, Any}()
52-
53-
# We make this function explicitly callable so that we can drive OrcJIT's
54-
# lazy compilation from, while also enabling recursive compilation.
55-
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
56-
ptr
57-
end
58-
59-
@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
60-
id = length(deferred_codegen_jobs) + 1
61-
deferred_codegen_jobs[id] = (; ft, tt)
62-
# don't bother looking up the method instance, as we'll do so again during codegen
63-
# using the world age of the parent.
64-
#
65-
# this also works around an issue on <1.10, where we don't know the world age of
66-
# generated functions so use the current world counter, which may be too new
67-
# for the world we're compiling for.
68-
69-
quote
70-
# TODO: add an edge to this method instance to support method redefinitions
71-
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
72-
end
73-
end
74-
end
75-
76-
7746
## compiler entrypoint
7847

7948
export compile
@@ -198,7 +167,6 @@ const __llvm_initialized = Ref(false)
198167

199168
# gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
200169
# target method instance from the LLVM IR
201-
# TODO: drive deferred compilation from the Julia IR instead
202170
function find_base_object(val)
203171
while true
204172
if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
@@ -263,80 +231,6 @@ const __llvm_initialized = Ref(false)
263231
@compiler_assert isempty(uses(dyn_marker)) job
264232
unsafe_delete!(ir, dyn_marker)
265233
end
266-
## old, deprecated implementation
267-
jobs = Dict{CompilerJob, String}(job => entry_fn)
268-
if toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
269-
run_optimization_for_deferred = true
270-
dyn_marker = functions(ir)["deferred_codegen"]
271-
272-
# iterative compilation (non-recursive)
273-
changed = true
274-
while changed
275-
changed = false
276-
277-
# find deferred compiler
278-
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
279-
for use in uses(dyn_marker)
280-
# decode the call
281-
call = user(use)::LLVM.CallInst
282-
id = convert(Int, first(operands(call)))
283-
284-
global deferred_codegen_jobs
285-
dyn_val = deferred_codegen_jobs[id]
286-
287-
# get a job in the appopriate world
288-
dyn_job = if dyn_val isa CompilerJob
289-
# trust that the user knows what they're doing
290-
dyn_val
291-
else
292-
ft, tt = dyn_val
293-
dyn_src = methodinstance(ft, tt, tls_world_age())
294-
CompilerJob(dyn_src, job.config)
295-
end
296-
297-
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
298-
end
299-
300-
# compile and link
301-
for dyn_job in keys(worklist)
302-
# cached compilation
303-
dyn_entry_fn = get!(jobs, dyn_job) do
304-
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false,
305-
parent_job=job)
306-
dyn_entry_fn = LLVM.name(dyn_meta.entry)
307-
merge!(compiled, dyn_meta.compiled)
308-
@assert context(dyn_ir) == context(ir)
309-
link!(ir, dyn_ir)
310-
changed = true
311-
dyn_entry_fn
312-
end
313-
dyn_entry = functions(ir)[dyn_entry_fn]
314-
315-
# insert a pointer to the function everywhere the entry is used
316-
T_ptr = convert(LLVMType, Ptr{Cvoid})
317-
for call in worklist[dyn_job]
318-
@dispose builder=IRBuilder() begin
319-
position!(builder, call)
320-
fptr = if LLVM.version() >= v"17"
321-
T_ptr = LLVM.PointerType()
322-
bitcast!(builder, dyn_entry, T_ptr)
323-
elseif VERSION >= v"1.12.0-DEV.225"
324-
T_ptr = LLVM.PointerType(LLVM.Int8Type())
325-
bitcast!(builder, dyn_entry, T_ptr)
326-
else
327-
ptrtoint!(builder, dyn_entry, T_ptr)
328-
end
329-
replace_uses!(call, fptr)
330-
end
331-
unsafe_delete!(LLVM.parent(call), call)
332-
end
333-
end
334-
end
335-
336-
# all deferred compilations should have been resolved
337-
@compiler_assert isempty(uses(dyn_marker)) job
338-
unsafe_delete!(ir, dyn_marker)
339-
end
340234

341235
if libraries
342236
# load the runtime outside of a timing block (because it recurses into the compiler)
@@ -433,8 +327,8 @@ const __llvm_initialized = Ref(false)
433327
# finish the module
434328
#
435329
# we want to finish the module after optimization, so we cannot do so
436-
# during deferred code generation. instead, process the deferred jobs
437-
# here.
330+
# during deferred code generation. Instead, process the merged module
331+
# from all the jobs here.
438332
if toplevel
439333
entry = finish_ir!(job, ir, entry)
440334

0 commit comments

Comments
 (0)