Skip to content

Commit b04f2f2

Browse files
committed
add compiler support for gpuc.lookup
1 parent 440d6be commit b04f2f2

File tree

2 files changed

+63
-23
lines changed

2 files changed

+63
-23
lines changed

src/driver.jl

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob);
131131
end
132132

133133
# GPUCompiler intrinsic that marks deferred compilation
134+
# In contrast to `deferred_codegen` this doesn't support arbitrary
135+
# jobs as call targets.
134136
# Declared without any methods: the function exists only as a named marker.
function var"gpuc.deferred" end
135137

136138
# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
@@ -188,12 +190,28 @@ const __llvm_initialized = Ref(false)
188190
# since those modules have been finalized themselves, and we don't want to re-finalize.
189191
entry = finish_module!(job, ir, entry)
190192

193+
# Peel cast-like constant expressions off `val` and return what is underneath.
# While `val` is a `ConstantExpr` whose opcode is IntToPtr, BitCast or
# AddrSpaceCast, it is replaced by its first operand. Anything else (a
# non-`ConstantExpr`, or a `ConstantExpr` with a different opcode) is returned
# unchanged. Used further down to recover the constant operand of a
# `gpuc.lookup` call before converting it to an integer.
function unwrap_constant(val)
194+
while val isa ConstantExpr
195+
# only these three cast opcodes are looked through
if opcode(val) == LLVM.API.LLVMIntToPtr ||
196+
opcode(val) == LLVM.API.LLVMBitCast ||
197+
opcode(val) == LLVM.API.LLVMAddrSpaceCast
198+
# the value being cast is the first operand of the expression
val = first(operands(val))
199+
else
200+
# some other kind of constant expression: stop and return it as-is
break
201+
end
202+
end
203+
return val
204+
end
205+
191206
# deferred code generation
192207
has_deferred_jobs = !only_entry && toplevel &&
193-
haskey(functions(ir), "deferred_codegen")
208+
(haskey(functions(ir), "deferred_codegen") ||
209+
haskey(functions(ir), "gpuc.lookup"))
210+
194211
jobs = Dict{CompilerJob, String}(job => entry_fn)
195212
if has_deferred_jobs
196-
dyn_marker = functions(ir)["deferred_codegen"]
213+
dyn_marker = haskey(functions(ir), "deferred_codegen") ? functions(ir)["deferred_codegen"] : nothing
214+
dyn_marker_v2 = haskey(functions(ir), "gpuc.lookup") ? functions(ir)["gpuc.lookup"] : nothing
197215

198216
# iterative compilation (non-recursive)
199217
changed = true
@@ -202,26 +220,40 @@ const __llvm_initialized = Ref(false)
202220

203221
# find deferred compiler
204222
# TODO: recover this information earlier, from the Julia IR
223+
# We can do this now with gpuc.lookup
205224
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
206-
for use in uses(dyn_marker)
207-
# decode the call
208-
call = user(use)::LLVM.CallInst
209-
id = convert(Int, first(operands(call)))
210-
211-
global deferred_codegen_jobs
212-
dyn_val = deferred_codegen_jobs[id]
213-
214-
# get a job in the appropriate world
215-
dyn_job = if dyn_val isa CompilerJob
216-
# trust that the user knows what they're doing
217-
dyn_val
218-
else
219-
ft, tt = dyn_val
220-
dyn_src = methodinstance(ft, tt, tls_world_age())
221-
CompilerJob(dyn_src, job.config)
225+
if dyn_marker !== nothing
226+
for use in uses(dyn_marker)
227+
# decode the call
228+
call = user(use)::LLVM.CallInst
229+
id = convert(Int, first(operands(call)))
230+
231+
global deferred_codegen_jobs
232+
dyn_val = deferred_codegen_jobs[id]
233+
234+
# get a job in the appropriate world
235+
dyn_job = if dyn_val isa CompilerJob
236+
# trust that the user knows what they're doing
237+
dyn_val
238+
else
239+
ft, tt = dyn_val
240+
dyn_src = methodinstance(ft, tt, tls_world_age())
241+
CompilerJob(dyn_src, job.config)
242+
end
243+
244+
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
222245
end
246+
end
223247

224-
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
248+
if dyn_marker_v2 !== nothing
249+
for use in uses(dyn_marker_v2)
250+
# decode the call
251+
call = user(use)::LLVM.CallInst
252+
dyn_mi = Base.unsafe_pointer_to_objref(
253+
convert(Ptr{Cvoid}, convert(Int, unwrap_constant(operands(call)[1]))))
254+
dyn_job = CompilerJob(dyn_mi, job.config)
255+
push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
256+
end
225257
end
226258

227259
# compile and link
@@ -263,8 +295,15 @@ const __llvm_initialized = Ref(false)
263295
end
264296

265297
# all deferred compilations should have been resolved
266-
@compiler_assert isempty(uses(dyn_marker)) job
267-
unsafe_delete!(ir, dyn_marker)
298+
if dyn_marker !== nothing
299+
@compiler_assert isempty(uses(dyn_marker)) job
300+
unsafe_delete!(ir, dyn_marker)
301+
end
302+
303+
if dyn_marker_v2 !== nothing
304+
@compiler_assert isempty(uses(dyn_marker_v2)) job
305+
unsafe_delete!(ir, dyn_marker_v2)
306+
end
268307
end
269308

270309
if toplevel

src/jlgen.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,8 @@ else
318318
get_method_table_view(world::UInt, mt::MTType) = OverlayMethodTable(world, mt)
319319
end
320320

321-
struct GPUInterpreter <: CC.AbstractInterpreter
321+
# Supertype of `GPUInterpreter`; `CC.abstract_call_known` below dispatches on this
# abstract type rather than on the concrete struct.
# NOTE(review): presumably so other interpreters can subtype it and reuse that
# method — confirm against callers.
abstract type AbstractGPUInterpreter <: CC.AbstractInterpreter end
322+
struct GPUInterpreter <: AbstractGPUInterpreter
322323
world::UInt
323324
method_table::GPUMethodTableView
324325

@@ -440,7 +441,7 @@ struct DeferredCallInfo <: CC.CallInfo
440441
info::CC.CallInfo
441442
end
442443

443-
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
444+
function CC.abstract_call_known(interp::AbstractGPUInterpreter, @nospecialize(f),
444445
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
445446
max_methods::Int = CC.get_max_methods(interp, f, sv))
446447
(; fargs, argtypes) = arginfo

0 commit comments

Comments
 (0)