Skip to content

Commit 30da248

Browse files
committed
add worklist
1 parent ad1d6c6 commit 30da248

File tree

2 files changed

+83
-46
lines changed

2 files changed

+83
-46
lines changed

src/driver.jl

Lines changed: 57 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,11 @@ const __llvm_initialized = Ref(false)
237237
end
238238

239239
# deferred code generation
240-
has_deferred_jobs = !only_entry && toplevel &&
241-
(haskey(functions(ir), "deferred_codegen") ||
242-
haskey(functions(ir), "gpuc.lookup"))
240+
has_deferred_jobs = !only_entry && toplevel && haskey(functions(ir), "deferred_codegen")
243241

244242
jobs = Dict{CompilerJob, String}(job => entry_fn)
245243
if has_deferred_jobs
246-
dyn_marker = haskey(functions(ir), "deferred_codegen") ? functions(ir)["deferred_codegen"] : nothing
247-
dyn_marker_v2 = haskey(functions(ir), "gpuc.lookup") ? functions(ir)["gpuc.lookup"] : nothing
244+
dyn_marker = functions(ir)["deferred_codegen"]
248245

249246
# iterative compilation (non-recursive)
250247
changed = true
@@ -255,38 +252,25 @@ const __llvm_initialized = Ref(false)
255252
# TODO: recover this information earlier, from the Julia IR
256253
# We can do this now with gpuc.lookup
257254
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
258-
if dyn_marker !== nothing
259-
for use in uses(dyn_marker)
260-
# decode the call
261-
call = user(use)::LLVM.CallInst
262-
id = convert(Int, first(operands(call)))
263-
264-
global deferred_codegen_jobs
265-
dyn_val = deferred_codegen_jobs[id]
266-
267-
# get a job in the appropriate world
268-
dyn_job = if dyn_val isa CompilerJob
269-
# trust that the user knows what they're doing
270-
dyn_val
271-
else
272-
ft, tt = dyn_val
273-
dyn_src = methodinstance(ft, tt, tls_world_age())
274-
CompilerJob(dyn_src, job.config)
275-
end
276-
277-
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
255+
for use in uses(dyn_marker)
256+
# decode the call
257+
call = user(use)::LLVM.CallInst
258+
id = convert(Int, first(operands(call)))
259+
260+
global deferred_codegen_jobs
261+
dyn_val = deferred_codegen_jobs[id]
262+
263+
# get a job in the appropriate world
264+
dyn_job = if dyn_val isa CompilerJob
265+
# trust that the user knows what they're doing
266+
dyn_val
267+
else
268+
ft, tt = dyn_val
269+
dyn_src = methodinstance(ft, tt, tls_world_age())
270+
CompilerJob(dyn_src, job.config)
278271
end
279-
end
280272

281-
if dyn_marker_v2 !== nothing
282-
for use in uses(dyn_marker_v2)
283-
# decode the call
284-
call = user(use)::LLVM.CallInst
285-
dyn_mi = Base.unsafe_pointer_to_objref(
286-
convert(Ptr{Cvoid}, convert(Int, unwrap_constant(operands(call)[1]))))
287-
dyn_job = CompilerJob(dyn_mi, job.config)
288-
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
289-
end
273+
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
290274
end
291275

292276
# compile and link
@@ -332,11 +316,46 @@ const __llvm_initialized = Ref(false)
332316
@compiler_assert isempty(uses(dyn_marker)) job
333317
unsafe_delete!(ir, dyn_marker)
334318
end
319+
end
320+
321+
if haskey(functions(ir), "gpuc.lookup")
322+
dyn_marker = functions(ir)["gpuc.lookup"]
335323

336-
if dyn_marker_v2 !== nothing
337-
@compiler_assert isempty(uses(dyn_marker_v2)) job
338-
unsafe_delete!(ir, dyn_marker_v2)
324+
worklist = Dict{Any, Vector{LLVM.CallInst}}()
325+
for use in uses(dyn_marker)
326+
# decode the call
327+
call = user(use)::LLVM.CallInst
328+
dyn_mi = Base.unsafe_pointer_to_objref(
329+
convert(Ptr{Cvoid}, convert(Int, unwrap_constant(operands(call)[1]))))
330+
push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
339331
end
332+
333+
for dyn_mi in keys(worklist)
334+
dyn_fn_name = compiled[dyn_mi].specfunc
335+
dyn_fn = functions(ir)[dyn_fn_name]
336+
337+
# insert a pointer to the function everywhere the entry is used
338+
T_ptr = convert(LLVMType, Ptr{Cvoid})
339+
for call in worklist[dyn_mi]
340+
@dispose builder=IRBuilder() begin
341+
position!(builder, call)
342+
fptr = if LLVM.version() >= v"17"
343+
T_ptr = LLVM.PointerType()
344+
bitcast!(builder, dyn_fn, T_ptr)
345+
elseif VERSION >= v"1.12.0-DEV.225"
346+
T_ptr = LLVM.PointerType(LLVM.Int8Type())
347+
bitcast!(builder, dyn_fn, T_ptr)
348+
else
349+
ptrtoint!(builder, dyn_fn, T_ptr)
350+
end
351+
replace_uses!(call, fptr)
352+
end
353+
unsafe_delete!(LLVM.parent(call), call)
354+
end
355+
end
356+
357+
@compiler_assert isempty(uses(dyn_marker)) job
358+
unsafe_delete!(ir, dyn_marker)
340359
end
341360

342361
if toplevel

src/jlgen.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,24 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
668668
error("Cannot compile $(job.source) for world $(job.world); method is only valid in worlds $(job.source.def.primary_world) to $(job.source.def.deleted_world)")
669669
end
670670

671+
compiled = IdDict()
672+
llvm_mod, outstanding = compile_method_instance(job, compiled)
673+
worklist = outstanding
674+
while !isempty(worklist)
675+
source = pop!(worklist)
676+
haskey(compiled, source) && continue
677+
job2 = CompilerJob(source, job.config)
678+
@debug "Processing..." job2
679+
llvm_mod2, outstanding = compile_method_instance(job2, compiled)
680+
append!(worklist, outstanding)
681+
@assert context(llvm_mod) == context(llvm_mod2)
682+
link!(llvm_mod, llvm_mod2)
683+
end
684+
685+
return llvm_mod, compiled
686+
end
687+
688+
function compile_method_instance(@nospecialize(job::CompilerJob), compiled::IdDict{Any, Any})
671689
# populate the cache
672690
interp = get_interpreter(job)
673691
cache = CC.code_cache(interp)
@@ -678,7 +696,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
678696

679697
# create a callback to look-up function in our cache,
680698
# and keep track of the method instances we needed.
681-
method_instances = []
699+
method_instances = Any[]
682700
if Sys.ARCH == :x86 || Sys.ARCH == :x86_64
683701
function lookup_fun(mi, min_world, max_world)
684702
push!(method_instances, mi)
@@ -780,7 +798,6 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
780798
end
781799

782800
# process all compiled method instances
783-
compiled = Dict()
784801
for mi in method_instances
785802
ci = ci_cache_lookup(cache, mi, job.world, job.world)
786803
ci === nothing && continue
@@ -819,7 +836,9 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
819836

820837
# We don't control the interp that codegen constructs for us above.
821838
# So we have to scan the IR manually.
822-
for (mi, (ci::CodeInstance, _, _)) in compiled
839+
outstanding = Any[]
840+
for mi in method_instances
841+
ci = compiled[mi].ci
823842
src = @atomic :monotonic ci.inferred
824843
if src isa String
825844
src = Core.Compiler._uncompressed_ir(mi.def, src)
@@ -829,10 +848,9 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
829848
if expr.head === :foreigncall &&
830849
expr.args[1] == "extern gpuc.lookup"
831850
deferred_mi = expr.args[6]
832-
# Now push to a worklist and process...
833-
# TODO: How do we deal with call duplication?
834-
# Can we codegen into the same module, or do we merge?
835-
# we can check against "compiled" to avoid recursion?
851+
if !haskey(compiled, deferred_mi)
852+
push!(outstanding, deferred_mi)
853+
end
836854
end
837855
end
838856
end
@@ -848,7 +866,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
848866
end
849867
end
850868

851-
return llvm_mod, compiled
869+
return llvm_mod, outstanding
852870
end
853871

854872
# Narrow this if JuliaLang/julia#54069 gets backported to 1.11

0 commit comments

Comments
 (0)