Skip to content

Commit 4a11bd1

Browse files
authored
Merge pull request #249 from JuliaGPU/tb/pass_state_ccall
Fixes for late optimization of kernel state arguments
2 parents f4e9137 + 56ec519 commit 4a11bd1

File tree

8 files changed

+316
-93
lines changed

8 files changed

+316
-93
lines changed

src/driver.jl

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,25 @@ const __llvm_initialized = Ref(false)
179179
end
180180
end
181181

182+
# mark everything internal except for the entry and any exported global variables.
183+
# this makes sure that the optimizer can, e.g., touch function signatures.
184+
ModulePassManager() do pm
185+
# NOTE: this needs to happen after linking libraries to remove unused functions,
186+
# but before deferred codegen so that all kernels remain available.
187+
exports = String[entry_fn]
188+
for gvar in globals(ir)
189+
if linkage(gvar) == LLVM.API.LLVMExternalLinkage
190+
push!(exports, LLVM.name(gvar))
191+
end
192+
end
193+
internalize!(pm, exports)
194+
run!(pm, ir)
195+
end
196+
197+
# finalize the current module. this needs to happen before linking deferred modules,
198+
# since those modules have been finalized themselves, and we don't want to re-finalize.
199+
entry = finish_module!(job, ir, entry)
200+
182201
# deferred code generation
183202
do_deferred_codegen = !only_entry && deferred_codegen &&
184203
haskey(functions(ir), "deferred_codegen")
@@ -242,39 +261,9 @@ const __llvm_initialized = Ref(false)
242261
end
243262

244263
@timeit_debug to "IR post-processing" begin
245-
entry = finish_module!(job, ir, entry)
246-
247-
if optimize
248-
@timeit_debug to "optimization" optimize!(job, ir)
249-
250-
# optimization may have replaced functions, so look the entry point up again
251-
entry = functions(ir)[entry_fn]
252-
end
253-
254-
if ccall(:jl_is_debugbuild, Cint, ()) == 1
255-
@timeit_debug to "verification" verify(ir)
256-
end
257-
264+
# some early clean-up to reduce the amount of code to optimize
258265
@timeit_debug to "clean-up" begin
259-
# replace non-entry function definitions with a declaration
260-
if only_entry
261-
for f in functions(ir)
262-
f == entry && continue
263-
isdeclaration(f) && continue
264-
LLVM.isintrinsic(f) && continue
265-
empty!(f)
266-
end
267-
end
268-
269-
# remove everything except for the entry and any exported global variables
270-
exports = String[entry_fn]
271-
for gvar in globals(ir)
272-
push!(exports, LLVM.name(gvar))
273-
end
274-
275266
ModulePassManager() do pm
276-
internalize!(pm, exports)
277-
278267
# eliminate all unused internal functions
279268
global_optimizer!(pm)
280269
global_dce!(pm)
@@ -283,9 +272,20 @@ const __llvm_initialized = Ref(false)
283272
# merge constants (such as exception messages)
284273
constant_merge!(pm)
285274

286-
if do_deferred_codegen
287-
# inline and optimize the call to the deferred code. in particular we want to
288-
# remove unnecessary alloca's that are created by pass-by-ref semantics.
275+
run!(pm, ir)
276+
end
277+
end
278+
279+
if optimize
280+
@timeit_debug to "optimization" begin
281+
optimize!(job, ir)
282+
283+
# deferred codegen has some special optimization requirements,
284+
# which also need to happen _after_ regular optimization.
285+
# XXX: make these part of the optimizer pipeline?
286+
do_deferred_codegen && ModulePassManager() do pm
287+
# inline and optimize the call to the deferred code. in particular we want
288+
# to remove unnecessary alloca's created by pass-by-ref semantics.
289289
instruction_combining!(pm)
290290
always_inliner!(pm)
291291
scalar_repl_aggregates_ssa!(pm)
@@ -295,11 +295,30 @@ const __llvm_initialized = Ref(false)
295295
# merge duplicate functions, since each compilation invocation emits everything
296296
# XXX: ideally we want to avoid emitting these in the first place
297297
merge_functions!(pm)
298+
299+
run!(pm, ir)
298300
end
301+
end
299302

300-
run!(pm, ir)
303+
# optimization may have replaced functions, so look the entry point up again
304+
entry = functions(ir)[entry_fn]
305+
end
306+
307+
# replace non-entry function definitions with a declaration
308+
# NOTE: we can't do this before optimization, because the definitions of called
309+
# functions may affect optimization.
310+
if only_entry
311+
for f in functions(ir)
312+
f == entry && continue
313+
isdeclaration(f) && continue
314+
LLVM.isintrinsic(f) && continue
315+
empty!(f)
301316
end
302317
end
318+
319+
if ccall(:jl_is_debugbuild, Cint, ()) == 1
320+
@timeit_debug to "verification" verify(ir)
321+
end
303322
end
304323

305324
return ir, (; entry, compiled)

src/interface.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -205,20 +205,8 @@ function finish_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry
205205
entry_fn = LLVM.name(entry)
206206

207207
# add the kernel state, and lower calls to the `julia.gpu.state_getter` intrinsic.
208-
# we do this _after_ optimization, because the runtime is linked after optimization too.
209208
if job.source.kernel
210-
state = kernel_state_type(job)
211-
if state !== Nothing
212-
T_state = convert(LLVMType, state; ctx)
213-
add_kernel_state!(job, mod, entry, T_state)
214-
end
215-
216-
# don't pass the state when unnecessary
217-
# XXX: only apply in add_kernel_state! when needed?
218-
ModulePassManager() do pm
219-
dead_arg_elimination!(pm)
220-
run!(pm, mod)
221-
end
209+
add_kernel_state!(job, mod, entry)
222210
end
223211

224212
return functions(mod)[entry_fn]

0 commit comments

Comments
 (0)