Skip to content

Commit 36abc73

Browse files
committed
add worklist
1 parent 448a8e1 commit 36abc73

File tree

2 files changed

+83
-46
lines changed

2 files changed

+83
-46
lines changed

src/driver.jl

Lines changed: 57 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -204,14 +204,11 @@ const __llvm_initialized = Ref(false)
204204
end
205205

206206
# deferred code generation
207-
has_deferred_jobs = !only_entry && toplevel &&
208-
(haskey(functions(ir), "deferred_codegen") ||
209-
haskey(functions(ir), "gpuc.lookup"))
207+
has_deferred_jobs = !only_entry && toplevel && haskey(functions(ir), "deferred_codegen")
210208

211209
jobs = Dict{CompilerJob, String}(job => entry_fn)
212210
if has_deferred_jobs
213-
dyn_marker = haskey(functions(ir), "deferred_codegen") ? functions(ir)["deferred_codegen"] : nothing
214-
dyn_marker_v2 = haskey(functions(ir), "gpuc.lookup") ? functions(ir)["gpuc.lookup"] : nothing
211+
dyn_marker = functions(ir)["deferred_codegen"]
215212

216213
# iterative compilation (non-recursive)
217214
changed = true
@@ -222,38 +219,25 @@ const __llvm_initialized = Ref(false)
222219
# TODO: recover this information earlier, from the Julia IR
223220
# We can do this now with gpuc.lookup
224221
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
225-
if dyn_marker !== nothing
226-
for use in uses(dyn_marker)
227-
# decode the call
228-
call = user(use)::LLVM.CallInst
229-
id = convert(Int, first(operands(call)))
230-
231-
global deferred_codegen_jobs
232-
dyn_val = deferred_codegen_jobs[id]
233-
234-
# get a job in the appopriate world
235-
dyn_job = if dyn_val isa CompilerJob
236-
# trust that the user knows what they're doing
237-
dyn_val
238-
else
239-
ft, tt = dyn_val
240-
dyn_src = methodinstance(ft, tt, tls_world_age())
241-
CompilerJob(dyn_src, job.config)
242-
end
243-
244-
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
222+
for use in uses(dyn_marker)
223+
# decode the call
224+
call = user(use)::LLVM.CallInst
225+
id = convert(Int, first(operands(call)))
226+
227+
global deferred_codegen_jobs
228+
dyn_val = deferred_codegen_jobs[id]
229+
230+
# get a job in the appopriate world
231+
dyn_job = if dyn_val isa CompilerJob
232+
# trust that the user knows what they're doing
233+
dyn_val
234+
else
235+
ft, tt = dyn_val
236+
dyn_src = methodinstance(ft, tt, tls_world_age())
237+
CompilerJob(dyn_src, job.config)
245238
end
246-
end
247239

248-
if dyn_marker_v2 !== nothing
249-
for use in uses(dyn_marker_v2)
250-
# decode the call
251-
call = user(use)::LLVM.CallInst
252-
dyn_mi = Base.unsafe_pointer_to_objref(
253-
convert(Ptr{Cvoid}, convert(Int, unwrap_constant(operands(call)[1]))))
254-
dyn_job = CompilerJob(dyn_mi, job.config)
255-
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
256-
end
240+
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
257241
end
258242

259243
# compile and link
@@ -299,11 +283,46 @@ const __llvm_initialized = Ref(false)
299283
@compiler_assert isempty(uses(dyn_marker)) job
300284
unsafe_delete!(ir, dyn_marker)
301285
end
286+
end
287+
288+
if haskey(functions(ir), "gpuc.lookup")
289+
dyn_marker = functions(ir)["gpuc.lookup"]
302290

303-
if dyn_marker_v2 !== nothing
304-
@compiler_assert isempty(uses(dyn_marker_v2)) job
305-
unsafe_delete!(ir, dyn_marker_v2)
291+
worklist = Dict{Any, Vector{LLVM.CallInst}}()
292+
for use in uses(dyn_marker)
293+
# decode the call
294+
call = user(use)::LLVM.CallInst
295+
dyn_mi = Base.unsafe_pointer_to_objref(
296+
convert(Ptr{Cvoid}, convert(Int, unwrap_constant(operands(call)[1]))))
297+
push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
306298
end
299+
300+
for dyn_mi in keys(worklist)
301+
dyn_fn_name = compiled[dyn_mi].specfunc
302+
dyn_fn = functions(ir)[dyn_fn_name]
303+
304+
# insert a pointer to the function everywhere the entry is used
305+
T_ptr = convert(LLVMType, Ptr{Cvoid})
306+
for call in worklist[dyn_mi]
307+
@dispose builder=IRBuilder() begin
308+
position!(builder, call)
309+
fptr = if LLVM.version() >= v"17"
310+
T_ptr = LLVM.PointerType()
311+
bitcast!(builder, dyn_fn, T_ptr)
312+
elseif VERSION >= v"1.12.0-DEV.225"
313+
T_ptr = LLVM.PointerType(LLVM.Int8Type())
314+
bitcast!(builder, dyn_fn, T_ptr)
315+
else
316+
ptrtoint!(builder, dyn_fn, T_ptr)
317+
end
318+
replace_uses!(call, fptr)
319+
end
320+
unsafe_delete!(LLVM.parent(call), call)
321+
end
322+
end
323+
324+
@compiler_assert isempty(uses(dyn_marker)) job
325+
unsafe_delete!(ir, dyn_marker)
307326
end
308327

309328
if toplevel

src/jlgen.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,24 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
639639
error("Cannot compile $(job.source) for world $(job.world); method is only valid in worlds $(job.source.def.primary_world) to $(job.source.def.deleted_world)")
640640
end
641641

642+
compiled = IdDict()
643+
llvm_mod, outstanding = compile_method_instance(job, compiled)
644+
worklist = outstanding
645+
while !isempty(worklist)
646+
source = pop!(worklist)
647+
haskey(compiled, source) && continue
648+
job2 = CompilerJob(source, job.config)
649+
@debug "Processing..." job2
650+
llvm_mod2, outstanding = compile_method_instance(job2, compiled)
651+
append!(worklist, outstanding)
652+
@assert context(llvm_mod) == context(llvm_mod2)
653+
link!(llvm_mod, llvm_mod2)
654+
end
655+
656+
return llvm_mod, compiled
657+
end
658+
659+
function compile_method_instance(@nospecialize(job::CompilerJob), compiled::IdDict{Any, Any})
642660
# populate the cache
643661
interp = get_interpreter(job)
644662
cache = CC.code_cache(interp)
@@ -649,7 +667,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
649667

650668
# create a callback to look-up function in our cache,
651669
# and keep track of the method instances we needed.
652-
method_instances = []
670+
method_instances = Any[]
653671
if Sys.ARCH == :x86 || Sys.ARCH == :x86_64
654672
function lookup_fun(mi, min_world, max_world)
655673
push!(method_instances, mi)
@@ -714,7 +732,6 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
714732
end
715733

716734
# process all compiled method instances
717-
compiled = Dict()
718735
for mi in method_instances
719736
ci = ci_cache_lookup(cache, mi, job.world, job.world)
720737
ci === nothing && continue
@@ -753,7 +770,9 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
753770

754771
# We don't control the interp that codegen constructs for us above.
755772
# So we have to scan the IR manually.
756-
for (mi, (ci::CodeInstance, _, _)) in compiled
773+
outstanding = Any[]
774+
for mi in method_instances
775+
ci = compiled[mi].ci
757776
src = @atomic :monotonic ci.inferred
758777
if src isa String
759778
src = Core.Compiler._uncompressed_ir(mi.def, src)
@@ -763,18 +782,17 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
763782
if expr.head === :foreigncall &&
764783
expr.args[1] == "extern gpuc.lookup"
765784
deferred_mi = expr.args[6]
766-
# Now push to a worklist and process...
767-
# TODO: How do we deal with call duplication?
768-
# Can we codegen into the same module, or do we merge?
769-
# we can check against "compiled" to avoid recursion?
785+
if !haskey(compiled, deferred_mi)
786+
push!(outstanding, deferred_mi)
787+
end
770788
end
771789
end
772790
end
773791

774792
# ensure that the requested method instance was compiled
775793
@assert haskey(compiled, job.source)
776794

777-
return llvm_mod, compiled
795+
return llvm_mod, outstanding
778796
end
779797

780798
# partially revert JuliaLangjulia#49391

0 commit comments

Comments
 (0)