@@ -204,14 +204,11 @@ const __llvm_initialized = Ref(false)
204
204
end
205
205
206
206
# deferred code generation
207
- has_deferred_jobs = ! only_entry && toplevel &&
208
- (haskey (functions (ir), " deferred_codegen" ) ||
209
- haskey (functions (ir), " gpuc.lookup" ))
207
+ has_deferred_jobs = ! only_entry && toplevel && haskey (functions (ir), " deferred_codegen" )
210
208
211
209
jobs = Dict {CompilerJob, String} (job => entry_fn)
212
210
if has_deferred_jobs
213
- dyn_marker = haskey (functions (ir), " deferred_codegen" ) ? functions (ir)[" deferred_codegen" ] : nothing
214
- dyn_marker_v2 = haskey (functions (ir), " gpuc.lookup" ) ? functions (ir)[" gpuc.lookup" ] : nothing
211
+ dyn_marker = functions (ir)[" deferred_codegen" ]
215
212
216
213
# iterative compilation (non-recursive)
217
214
changed = true
@@ -222,38 +219,25 @@ const __llvm_initialized = Ref(false)
222
219
# TODO : recover this information earlier, from the Julia IR
223
220
# We can do this now with gpuc.lookup
224
221
worklist = Dict {CompilerJob, Vector{LLVM.CallInst}} ()
225
- if dyn_marker != = nothing
226
- for use in uses (dyn_marker)
227
- # decode the call
228
- call = user (use):: LLVM.CallInst
229
- id = convert (Int, first (operands (call)))
230
-
231
- global deferred_codegen_jobs
232
- dyn_val = deferred_codegen_jobs[id]
233
-
234
- # get a job in the appopriate world
235
- dyn_job = if dyn_val isa CompilerJob
236
- # trust that the user knows what they're doing
237
- dyn_val
238
- else
239
- ft, tt = dyn_val
240
- dyn_src = methodinstance (ft, tt, tls_world_age ())
241
- CompilerJob (dyn_src, job. config)
242
- end
243
-
244
- push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
222
+ for use in uses (dyn_marker)
223
+ # decode the call
224
+ call = user (use):: LLVM.CallInst
225
+ id = convert (Int, first (operands (call)))
226
+
227
+ global deferred_codegen_jobs
228
+ dyn_val = deferred_codegen_jobs[id]
229
+
230
+ # get a job in the appopriate world
231
+ dyn_job = if dyn_val isa CompilerJob
232
+ # trust that the user knows what they're doing
233
+ dyn_val
234
+ else
235
+ ft, tt = dyn_val
236
+ dyn_src = methodinstance (ft, tt, tls_world_age ())
237
+ CompilerJob (dyn_src, job. config)
245
238
end
246
- end
247
239
248
- if dyn_marker_v2 != = nothing
249
- for use in uses (dyn_marker_v2)
250
- # decode the call
251
- call = user (use):: LLVM.CallInst
252
- dyn_mi = Base. unsafe_pointer_to_objref (
253
- convert (Ptr{Cvoid}, convert (Int, unwrap_constant (operands (call)[1 ]))))
254
- dyn_job = CompilerJob (dyn_mi, job. config)
255
- push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
256
- end
240
+ push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
257
241
end
258
242
259
243
# compile and link
@@ -299,11 +283,46 @@ const __llvm_initialized = Ref(false)
299
283
@compiler_assert isempty (uses (dyn_marker)) job
300
284
unsafe_delete! (ir, dyn_marker)
301
285
end
286
+ end
287
+
288
+ if haskey (functions (ir), " gpuc.lookup" )
289
+ dyn_marker = functions (ir)[" gpuc.lookup" ]
302
290
303
- if dyn_marker_v2 != = nothing
304
- @compiler_assert isempty (uses (dyn_marker_v2)) job
305
- unsafe_delete! (ir, dyn_marker_v2)
291
+ worklist = Dict {Any, Vector{LLVM.CallInst}} ()
292
+ for use in uses (dyn_marker)
293
+ # decode the call
294
+ call = user (use):: LLVM.CallInst
295
+ dyn_mi = Base. unsafe_pointer_to_objref (
296
+ convert (Ptr{Cvoid}, convert (Int, unwrap_constant (operands (call)[1 ]))))
297
+ push! (get! (worklist, dyn_mi, LLVM. CallInst[]), call)
306
298
end
299
+
300
+ for dyn_mi in keys (worklist)
301
+ dyn_fn_name = compiled[dyn_mi]. specfunc
302
+ dyn_fn = functions (ir)[dyn_fn_name]
303
+
304
+ # insert a pointer to the function everywhere the entry is used
305
+ T_ptr = convert (LLVMType, Ptr{Cvoid})
306
+ for call in worklist[dyn_mi]
307
+ @dispose builder= IRBuilder () begin
308
+ position! (builder, call)
309
+ fptr = if LLVM. version () >= v " 17"
310
+ T_ptr = LLVM. PointerType ()
311
+ bitcast! (builder, dyn_fn, T_ptr)
312
+ elseif VERSION >= v " 1.12.0-DEV.225"
313
+ T_ptr = LLVM. PointerType (LLVM. Int8Type ())
314
+ bitcast! (builder, dyn_fn, T_ptr)
315
+ else
316
+ ptrtoint! (builder, dyn_fn, T_ptr)
317
+ end
318
+ replace_uses! (call, fptr)
319
+ end
320
+ unsafe_delete! (LLVM. parent (call), call)
321
+ end
322
+ end
323
+
324
+ @compiler_assert isempty (uses (dyn_marker)) job
325
+ unsafe_delete! (ir, dyn_marker)
307
326
end
308
327
309
328
if toplevel
0 commit comments