@@ -39,6 +39,41 @@ function JuliaContext(f; kwargs...)
 end
 
 
+## deferred compilation
+
+function var"gpuc.deferred" end
+
+# old, deprecated mechanism slated for removal once Enzyme is updated to the new intrinsic
+begin
+    # primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
+    # this could both be generalized (e.g. supporting actual function calls, instead of
+    # returning a function pointer), and be integrated with the nonrecursive codegen.
+    const deferred_codegen_jobs = Dict{Int, Any}()
+
+    # We make this function explicitly callable so that we can drive OrcJIT's
+    # lazy compilation from it, while also enabling recursive compilation.
+    Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
+        ptr
+    end
+
+    @generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
+        id = length(deferred_codegen_jobs) + 1
+        deferred_codegen_jobs[id] = (; ft, tt)
+        # don't bother looking up the method instance, as we'll do so again during codegen
+        # using the world age of the parent.
+        #
+        # this also works around an issue on <1.10, where we don't know the world age of
+        # generated functions, so use the current world counter, which may be too new
+        # for the world we're compiling for.
+
+        quote
+            # TODO: add an edge to this method instance to support method redefinitions
+            ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
+        end
+    end
+end
+
+
 ## compiler entrypoint
 
 export compile
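For context on the deprecated path kept above: a back end (for instance CUDA.jl's dynamic parallelism) calls `deferred_codegen(Val(f), Val(tt))`, the `@generated` function records the target in `deferred_codegen_jobs` while generating the method body, and the emitted `ccall` to the external `deferred_codegen` symbol carries only the job id, which codegen later turns into a pointer to the compiled function. A minimal, self-contained sketch of that registration pattern (hypothetical names, plain Julia rather than GPUCompiler API):

```julia
# Toy version of the @generated registration trick used above: the generator
# runs once per specialization, records the request in a global table, and
# splices the resulting id into the code it returns.
const example_jobs = Dict{Int,Any}()

@generated function register_job(::Val{ft}, ::Val{tt}) where {ft,tt}
    id = length(example_jobs) + 1
    example_jobs[id] = (; ft, tt)   # recorded at generation time
    quote
        $id  # stands in for `ccall("extern deferred_codegen", ..., id)`
    end
end

register_job(Val(sin), Val(Tuple{Float64}))  # returns 1
example_jobs[1]                              # (ft = sin, tt = Tuple{Float64})
```

The new `var"gpuc.deferred"` intrinsic is meant to replace this: as the `emit_llvm` changes below show, its lowered `gpuc.lookup` calls are resolved to pointers to the compiled functions directly from the embedded method instance, without going through a global job table.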
@@ -127,33 +162,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
     error("Unknown compilation output $output")
 end
 
-# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
-# this could both be generalized (e.g. supporting actual function calls, instead of
-# returning a function pointer), and be integrated with the nonrecursive codegen.
-const deferred_codegen_jobs = Dict{Int, Any}()
-
-# We make this function explicitly callable so that we can drive OrcJIT's
-# lazy compilation from, while also enabling recursive compilation.
-Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
-    ptr
-end
-
-@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
-    id = length(deferred_codegen_jobs) + 1
-    deferred_codegen_jobs[id] = (; ft, tt)
-    # don't bother looking up the method instance, as we'll do so again during codegen
-    # using the world age of the parent.
-    #
-    # this also works around an issue on <1.10, where we don't know the world age of
-    # generated functions so use the current world counter, which may be too new
-    # for the world we're compiling for.
-
-    quote
-        # TODO: add an edge to this method instance to support method redefinitions
-        ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
-    end
-end
-
 const __llvm_initialized = Ref(false)
 
 @locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
@@ -183,9 +191,82 @@ const __llvm_initialized = Ref(false)
     entry = finish_module!(job, ir, entry)
 
     # deferred code generation
-    has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
+    run_optimization_for_deferred = false
+    if haskey(functions(ir), "gpuc.lookup")
+        run_optimization_for_deferred = true
+        dyn_marker = functions(ir)["gpuc.lookup"]
+
+        # gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
+        # target method instance from the LLVM IR
+        # TODO: drive deferred compilation from the Julia IR instead
+        function find_base_object(val)
+            while true
+                if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
+                                            opcode(val) == LLVM.API.LLVMBitCast ||
+                                            opcode(val) == LLVM.API.LLVMAddrSpaceCast)
+                    val = first(operands(val))
+                elseif val isa LLVM.IntToPtrInst ||
+                       val isa LLVM.BitCastInst ||
+                       val isa LLVM.AddrSpaceCastInst
+                    val = first(operands(val))
+                elseif val isa LLVM.LoadInst
+                    # In 1.11+ we no longer embed integer constants directly.
+                    gv = first(operands(val))
+                    if gv isa LLVM.GlobalValue
+                        val = LLVM.initializer(gv)
+                        continue
+                    end
+                    break
+                else
+                    break
+                end
+            end
+            return val
+        end
+
+        worklist = Dict{Any, Vector{LLVM.CallInst}}()
+        for use in uses(dyn_marker)
+            # decode the call
+            call = user(use)::LLVM.CallInst
+            dyn_mi_inst = find_base_object(operands(call)[1])
+            @compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
+            dyn_mi = Base.unsafe_pointer_to_objref(
+                convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
+            push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
+        end
+
+        for dyn_mi in keys(worklist)
+            dyn_fn_name = compiled[dyn_mi].specfunc
+            dyn_fn = functions(ir)[dyn_fn_name]
+
+            # insert a pointer to the function everywhere the entry is used
+            T_ptr = convert(LLVMType, Ptr{Cvoid})
+            for call in worklist[dyn_mi]
+                @dispose builder=IRBuilder() begin
+                    position!(builder, call)
+                    fptr = if LLVM.version() >= v"17"
+                        T_ptr = LLVM.PointerType()
+                        bitcast!(builder, dyn_fn, T_ptr)
+                    elseif VERSION >= v"1.12.0-DEV.225"
+                        T_ptr = LLVM.PointerType(LLVM.Int8Type())
+                        bitcast!(builder, dyn_fn, T_ptr)
+                    else
+                        ptrtoint!(builder, dyn_fn, T_ptr)
+                    end
+                    replace_uses!(call, fptr)
+                end
+                unsafe_delete!(LLVM.parent(call), call)
+            end
+        end
+
+        # all deferred compilations should have been resolved
+        @compiler_assert isempty(uses(dyn_marker)) job
+        unsafe_delete!(ir, dyn_marker)
+    end
+    ## old, deprecated implementation
     jobs = Dict{CompilerJob, String}(job => entry_fn)
-    if has_deferred_jobs
+    if toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
+        run_optimization_for_deferred = true
         dyn_marker = functions(ir)["deferred_codegen"]
 
         # iterative compilation (non-recursive)
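Worth spelling out from the `gpuc.lookup` handling above: the target `MethodInstance` reaches the LLVM IR as its address, embedded as a pointer-sized constant (possibly wrapped in `inttoptr`/`bitcast`/`addrspacecast`, or loaded from a global on 1.11+, which is what `find_base_object` peels away), and `Base.unsafe_pointer_to_objref` turns that address back into the Julia object. A minimal sketch of that round trip in plain Julia, with a `Ref` standing in for the method instance:

```julia
# Address-to-object round trip relied upon by the lookup code above.
obj = Ref(42)                                  # stands in for a MethodInstance
GC.@preserve obj begin
    addr = Int(Base.pointer_from_objref(obj))  # what the IR carries as a ConstantInt
    @assert Base.unsafe_pointer_to_objref(Ptr{Cvoid}(addr)) === obj
end
```

The `GC.@preserve` matters in the sketch because the raw address by itself does not keep `obj` alive.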
@@ -194,7 +275,6 @@ const __llvm_initialized = Ref(false)
             changed = false
 
             # find deferred compiler
-            # TODO: recover this information earlier, from the Julia IR
             worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
             for use in uses(dyn_marker)
                 # decode the call
@@ -317,7 +397,7 @@ const __llvm_initialized = Ref(false)
     # deferred codegen has some special optimization requirements,
     # which also need to happen _after_ regular optimization.
     # XXX: make these part of the optimizer pipeline?
-    if has_deferred_jobs
+    if run_optimization_for_deferred
         @dispose pb=NewPMPassBuilder() begin
             add!(pb, NewPMFunctionPassManager()) do fpm
                 add!(fpm, InstCombinePass())