Skip to content

Commit 7fa1f9e

Browse files
committed
Convert initial kernel state transformation to a late pass.
1 parent dbc6902 commit 7fa1f9e

File tree

3 files changed

+10
-12
lines changed

3 files changed

+10
-12
lines changed

src/interface.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,7 @@ optimize_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module) = return
233233

234234
# finalization of the module, before deferred codegen and optimization
235235
function finish_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry::LLVM.Function)
236-
ctx = context(mod)
237-
entry_fn = LLVM.name(entry)
238-
239-
# add the kernel state, and lower calls to the `julia.gpu.state_getter` intrinsic.
240-
if job.source.kernel
241-
add_kernel_state!(job, mod, entry)
242-
end
243-
244-
return functions(mod)[entry_fn]
236+
return entry
245237
end
246238

247239
# final processing of the IR, right before validation and machine-code generation

src/irgen.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -535,10 +535,9 @@ end
535535
# so that the julia.gpu.state_getter` can be simplified to return an opaque pointer.
536536

537537
# add a state argument to every function in the module, starting from the kernel entry point
538-
function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
539-
entry::LLVM.Function)
538+
function add_kernel_state!(mod::LLVM.Module)
539+
job = current_job::CompilerJob
540540
ctx = context(mod)
541-
entry_fn = LLVM.name(entry)
542541

543542
# check if we even need a kernel state argument
544543
state = kernel_state_type(job)
@@ -552,6 +551,11 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
552551
# this is both for extern uses, and to make this transformation a two-step process.
553552
state_intr = kernel_state_intr(mod, T_state)
554553

554+
entry_md = operands(metadata(mod)["julia.entry"])[1]
555+
entry = Value(operands(entry_md)[1]; ctx)
556+
# XXX: this metadata will be invalid after the replacement here (it'll be null).
557+
# how do we replace Metadata uses? Normally RAUW, but it asserts type equality
558+
555559
# determine which functions need a kernel state argument
556560
#
557561
# previously, we add the argument to every function and relied on unused arg elim to

src/optim.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
192192
if job.source.kernel
193193
# GC lowering is the last pass that may introduce calls to the runtime library,
194194
# and thus additional uses of the kernel state intrinsic.
195+
# TODO: now that all kernel state-related passes are being run here, merge some?
196+
add!(pm, ModulePass("AddKernelState", add_kernel_state!))
195197
add!(pm, FunctionPass("LowerKernelState", lower_kernel_state!))
196198
add!(pm, ModulePass("CleanupKernelState", cleanup_kernel_state!))
197199
end

0 commit comments

Comments
 (0)