@@ -237,14 +237,11 @@ const __llvm_initialized = Ref(false)
237
237
end
238
238
239
239
# deferred code generation
240
- has_deferred_jobs = ! only_entry && toplevel &&
241
- (haskey (functions (ir), " deferred_codegen" ) ||
242
- haskey (functions (ir), " gpuc.lookup" ))
240
+ has_deferred_jobs = ! only_entry && toplevel && haskey (functions (ir), " deferred_codegen" )
243
241
244
242
jobs = Dict {CompilerJob, String} (job => entry_fn)
245
243
if has_deferred_jobs
246
- dyn_marker = haskey (functions (ir), " deferred_codegen" ) ? functions (ir)[" deferred_codegen" ] : nothing
247
- dyn_marker_v2 = haskey (functions (ir), " gpuc.lookup" ) ? functions (ir)[" gpuc.lookup" ] : nothing
244
+ dyn_marker = functions (ir)[" deferred_codegen" ]
248
245
249
246
# iterative compilation (non-recursive)
250
247
changed = true
@@ -255,38 +252,25 @@ const __llvm_initialized = Ref(false)
255
252
# TODO : recover this information earlier, from the Julia IR
256
253
# We can do this now with gpuc.lookup
257
254
worklist = Dict {CompilerJob, Vector{LLVM.CallInst}} ()
258
- if dyn_marker != = nothing
259
- for use in uses (dyn_marker)
260
- # decode the call
261
- call = user (use):: LLVM.CallInst
262
- id = convert (Int, first (operands (call)))
263
-
264
- global deferred_codegen_jobs
265
- dyn_val = deferred_codegen_jobs[id]
266
-
267
- # get a job in the appropriate world
268
- dyn_job = if dyn_val isa CompilerJob
269
- # trust that the user knows what they're doing
270
- dyn_val
271
- else
272
- ft, tt = dyn_val
273
- dyn_src = methodinstance (ft, tt, tls_world_age ())
274
- CompilerJob (dyn_src, job. config)
275
- end
276
-
277
- push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
255
+ for use in uses (dyn_marker)
256
+ # decode the call
257
+ call = user (use):: LLVM.CallInst
258
+ id = convert (Int, first (operands (call)))
259
+
260
+ global deferred_codegen_jobs
261
+ dyn_val = deferred_codegen_jobs[id]
262
+
263
+ # get a job in the appropriate world
264
+ dyn_job = if dyn_val isa CompilerJob
265
+ # trust that the user knows what they're doing
266
+ dyn_val
267
+ else
268
+ ft, tt = dyn_val
269
+ dyn_src = methodinstance (ft, tt, tls_world_age ())
270
+ CompilerJob (dyn_src, job. config)
278
271
end
279
- end
280
272
281
- if dyn_marker_v2 != = nothing
282
- for use in uses (dyn_marker_v2)
283
- # decode the call
284
- call = user (use):: LLVM.CallInst
285
- dyn_mi = Base. unsafe_pointer_to_objref (
286
- convert (Ptr{Cvoid}, convert (Int, unwrap_constant (operands (call)[1 ]))))
287
- dyn_job = CompilerJob (dyn_mi, job. config)
288
- push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
289
- end
273
+ push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
290
274
end
291
275
292
276
# compile and link
@@ -332,11 +316,46 @@ const __llvm_initialized = Ref(false)
332
316
@compiler_assert isempty (uses (dyn_marker)) job
333
317
unsafe_delete! (ir, dyn_marker)
334
318
end
319
+ end
320
+
321
+ if haskey (functions (ir), " gpuc.lookup" )
322
+ dyn_marker = functions (ir)[" gpuc.lookup" ]
335
323
336
- if dyn_marker_v2 != = nothing
337
- @compiler_assert isempty (uses (dyn_marker_v2)) job
338
- unsafe_delete! (ir, dyn_marker_v2)
324
+ worklist = Dict {Any, Vector{LLVM.CallInst}} ()
325
+ for use in uses (dyn_marker)
326
+ # decode the call
327
+ call = user (use):: LLVM.CallInst
328
+ dyn_mi = Base. unsafe_pointer_to_objref (
329
+ convert (Ptr{Cvoid}, convert (Int, unwrap_constant (operands (call)[1 ]))))
330
+ push! (get! (worklist, dyn_mi, LLVM. CallInst[]), call)
339
331
end
332
+
333
+ for dyn_mi in keys (worklist)
334
+ dyn_fn_name = compiled[dyn_mi]. specfunc
335
+ dyn_fn = functions (ir)[dyn_fn_name]
336
+
337
+ # insert a pointer to the function everywhere the entry is used
338
+ T_ptr = convert (LLVMType, Ptr{Cvoid})
339
+ for call in worklist[dyn_mi]
340
+ @dispose builder= IRBuilder () begin
341
+ position! (builder, call)
342
+ fptr = if LLVM. version () >= v " 17"
343
+ T_ptr = LLVM. PointerType ()
344
+ bitcast! (builder, dyn_fn, T_ptr)
345
+ elseif VERSION >= v " 1.12.0-DEV.225"
346
+ T_ptr = LLVM. PointerType (LLVM. Int8Type ())
347
+ bitcast! (builder, dyn_fn, T_ptr)
348
+ else
349
+ ptrtoint! (builder, dyn_fn, T_ptr)
350
+ end
351
+ replace_uses! (call, fptr)
352
+ end
353
+ unsafe_delete! (LLVM. parent (call), call)
354
+ end
355
+ end
356
+
357
+ @compiler_assert isempty (uses (dyn_marker)) job
358
+ unsafe_delete! (ir, dyn_marker)
340
359
end
341
360
342
361
if toplevel
0 commit comments