@@ -131,6 +131,8 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob);
131
131
end
132
132
133
133
# GPUCompiler intrinsic that marks deferred compilation
134
+ # In contrast to `deferred_codegen` this doesn't support arbitrary
135
+ # jobs as call targets.
134
136
function var"gpuc.deferred" end
135
137
136
138
# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
@@ -188,12 +190,28 @@ const __llvm_initialized = Ref(false)
188
190
# since those modules have been finalized themselves, and we don't want to re-finalize.
189
191
entry = finish_module! (job, ir, entry)
190
192
193
+ function unwrap_constant (val)
194
+ while val isa ConstantExpr
195
+ if opcode (val) == LLVM. API. LLVMIntToPtr ||
196
+ opcode (val) == LLVM. API. LLVMBitCast ||
197
+ opcode (val) == LLVM. API. LLVMAddrSpaceCast
198
+ val = first (operands (val))
199
+ else
200
+ break
201
+ end
202
+ end
203
+ return val
204
+ end
205
+
191
206
# deferred code generation
192
207
has_deferred_jobs = ! only_entry && toplevel &&
193
- haskey (functions (ir), " deferred_codegen" )
208
+ (haskey (functions (ir), " deferred_codegen" ) ||
209
+ haskey (functions (ir), " gpuc.lookup" ))
210
+
194
211
jobs = Dict {CompilerJob, String} (job => entry_fn)
195
212
if has_deferred_jobs
196
- dyn_marker = functions (ir)[" deferred_codegen" ]
213
+ dyn_marker = haskey (functions (ir), " deferred_codegen" ) ? functions (ir)[" deferred_codegen" ] : nothing
214
+ dyn_marker_v2 = haskey (functions (ir), " gpuc.lookup" ) ? functions (ir)[" gpuc.lookup" ] : nothing
197
215
198
216
# iterative compilation (non-recursive)
199
217
changed = true
@@ -202,26 +220,40 @@ const __llvm_initialized = Ref(false)
202
220
203
221
# find deferred compiler
204
222
# TODO : recover this information earlier, from the Julia IR
223
+ # We can do this now with gpuc.lookup
205
224
worklist = Dict {CompilerJob, Vector{LLVM.CallInst}} ()
206
- for use in uses (dyn_marker)
207
- # decode the call
208
- call = user (use):: LLVM.CallInst
209
- id = convert (Int, first (operands (call)))
210
-
211
- global deferred_codegen_jobs
212
- dyn_val = deferred_codegen_jobs[id]
213
-
214
- # get a job in the appopriate world
215
- dyn_job = if dyn_val isa CompilerJob
216
- # trust that the user knows what they're doing
217
- dyn_val
218
- else
219
- ft, tt = dyn_val
220
- dyn_src = methodinstance (ft, tt, tls_world_age ())
221
- CompilerJob (dyn_src, job. config)
225
+ if dyn_marker != = nothing
226
+ for use in uses (dyn_marker)
227
+ # decode the call
228
+ call = user (use):: LLVM.CallInst
229
+ id = convert (Int, first (operands (call)))
230
+
231
+ global deferred_codegen_jobs
232
+ dyn_val = deferred_codegen_jobs[id]
233
+
234
+ # get a job in the appopriate world
235
+ dyn_job = if dyn_val isa CompilerJob
236
+ # trust that the user knows what they're doing
237
+ dyn_val
238
+ else
239
+ ft, tt = dyn_val
240
+ dyn_src = methodinstance (ft, tt, tls_world_age ())
241
+ CompilerJob (dyn_src, job. config)
242
+ end
243
+
244
+ push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
222
245
end
246
+ end
223
247
224
- push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
248
+ if dyn_marker_v2 != = nothing
249
+ for use in uses (dyn_marker_v2)
250
+ # decode the call
251
+ call = user (use):: LLVM.CallInst
252
+ dyn_mi = Base. unsafe_pointer_to_objref (
253
+ convert (Ptr{Cvoid}, convert (Int, unwrap_constant (operands (call)[1 ]))))
254
+ dyn_job = CompilerJob (dyn_mi, job. config)
255
+ push! (get! (worklist, dyn_job, LLVM. CallInst[]), call)
256
+ end
225
257
end
226
258
227
259
# compile and link
@@ -263,8 +295,15 @@ const __llvm_initialized = Ref(false)
263
295
end
264
296
265
297
# all deferred compilations should have been resolved
266
- @compiler_assert isempty (uses (dyn_marker)) job
267
- unsafe_delete! (ir, dyn_marker)
298
+ if dyn_marker != = nothing
299
+ @compiler_assert isempty (uses (dyn_marker)) job
300
+ unsafe_delete! (ir, dyn_marker)
301
+ end
302
+
303
+ if dyn_marker_v2 != = nothing
304
+ @compiler_assert isempty (uses (dyn_marker_v2)) job
305
+ unsafe_delete! (ir, dyn_marker_v2)
306
+ end
268
307
end
269
308
270
309
if toplevel
0 commit comments