Skip to content

Commit 3e1e45f

Browse files
authored
Introduce a kernel state argument. (#236)
1 parent c40b3bb commit 3e1e45f

File tree

10 files changed

+391
-49
lines changed

10 files changed

+391
-49
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
1414

1515
[compat]
1616
ExprTools = "0.1"
17-
LLVM = "4.3"
17+
LLVM = "4.4"
1818
TimerOutputs = "0.5"
1919
julia = "1.6"
2020

src/driver.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -224,6 +224,8 @@ const __llvm_initialized = Ref(false)
224224
end
225225
end
226226

227+
entry = finish_module!(job, ir, entry)
228+
227229
# deferred code generation
228230
if !only_entry && deferred_codegen && haskey(functions(ir), "deferred_codegen")
229231
dyn_marker = functions(ir)["deferred_codegen"]
@@ -299,8 +301,6 @@ const __llvm_initialized = Ref(false)
299301
unsafe_delete!(ir, dyn_marker)
300302
end
301303

302-
finish_module!(job, ir)
303-
304304
return ir, (; entry, compiled)
305305
end
306306

src/gcn.jl

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -44,11 +44,9 @@ function process_module!(job::CompilerJob{GCNCompilerTarget}, mod::LLVM.Module)
4444
end
4545

4646
function process_entry!(job::CompilerJob{GCNCompilerTarget}, mod::LLVM.Module, entry::LLVM.Function)
47-
invoke(process_entry!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
47+
entry = invoke(process_entry!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
4848

4949
if job.source.kernel
50-
entry = lower_byval(job, mod, entry)
51-
5250
# calling convention
5351
callconv!(entry, LLVM.API.LLVMAMDGPUKERNELCallConv)
5452
end
@@ -59,6 +57,7 @@ end
5957
function add_lowering_passes!(job::CompilerJob{GCNCompilerTarget}, pm::LLVM.PassManager)
6058
add!(pm, ModulePass("LowerThrowExtra", lower_throw_extra!))
6159
end
60+
6261
# We need to do alloca rewriting (from 0 to 5) after Julia's optimization
6362
# passes because of two reasons:
6463
# 1. Debug builds call the target verifier first, which would trip if AMDGPU
@@ -81,6 +80,21 @@ function optimize_module!(job::CompilerJob{GCNCompilerTarget}, mod::LLVM.Module)
8180
end
8281
end
8382

83+
function finish_module!(@nospecialize(job::CompilerJob{GCNCompilerTarget}),
84+
mod::LLVM.Module, entry::LLVM.Function)
85+
entry = invoke(finish_module!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
86+
87+
if job.source.kernel
88+
# work around bad byval codegen (JuliaGPU/GPUCompiler.jl#92)
89+
entry = lower_byval(job, mod, entry)
90+
end
91+
92+
return entry
93+
end
94+
95+
96+
## LLVM passes
97+
8498
function lower_throw_extra!(mod::LLVM.Module)
8599
job = current_job::CompilerJob
86100
ctx = context(mod)
@@ -95,7 +109,6 @@ function lower_throw_extra!(mod::LLVM.Module)
95109
r"julia___subarray_throw_boundserror.*",
96110
]
97111

98-
99112
for f in functions(mod)
100113
f_name = LLVM.name(f)
101114
for fn in throw_functions
@@ -139,6 +152,7 @@ function lower_throw_extra!(mod::LLVM.Module)
139152
end
140153
return changed
141154
end
155+
142156
function fix_alloca_addrspace!(fn::LLVM.Function)
143157
changed = false
144158
alloca_as = 5
@@ -166,7 +180,6 @@ function fix_alloca_addrspace!(fn::LLVM.Function)
166180
return changed
167181
end
168182

169-
170183
function emit_trap!(job::CompilerJob{GCNCompilerTarget}, builder, mod, inst)
171184
ctx = context(mod)
172185
trap = if haskey(functions(mod), "llvm.trap")

src/interface.jl

Lines changed: 31 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,13 @@ runtime_slug(@nospecialize(job::CompilerJob)) = error("Not implemented")
165165
# early processing of the newly generated LLVM IR module
166166
process_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module) = return
167167

168+
# the type of the kernel state object, or Nothing if this back-end doesn't need one.
169+
#
170+
# the generated code will be rewritten to include an object of this type as the first
171+
# argument to each kernel, and pass that object to every function that accesses the kernel
172+
# state (possibly indirectly) via the `kernel_state_pointer` function.
173+
kernel_state_type(@nospecialize(job::CompilerJob)) = Nothing
174+
168175
# early processing of the newly identified LLVM kernel function
169176
function process_entry!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
170177
entry::LLVM.Function)
@@ -173,7 +180,7 @@ function process_entry!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
173180
if job.source.kernel
174181
# pass all bitstypes by value; by default Julia passes aggregates by reference
175182
# (this improves performance, and is mandated by certain back-ends like SPIR-V).
176-
args = classify_arguments(job, entry)
183+
args = classify_arguments(job, eltype(llvmtype(entry)))
177184
for arg in args
178185
if arg.cc == BITS_REF
179186
attr = if LLVM.version() >= v"12"
@@ -193,7 +200,29 @@ end
193200
optimize_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module) = return
194201

195202
# final processing of the IR module, right before validation and machine-code generation
196-
finish_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module) = return
203+
function finish_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry::LLVM.Function)
204+
ctx = context(mod)
205+
entry_fn = LLVM.name(entry)
206+
207+
# add the kernel state, and lower calls to the `julia.gpu.state_getter` intrinsic.
208+
# we do this _after_ optimization, because the runtime is linked after optimization too.
209+
if job.source.kernel
210+
state = kernel_state_type(job)
211+
if state !== Nothing
212+
T_state = convert(LLVMType, state; ctx)
213+
add_kernel_state!(job, mod, entry, T_state)
214+
end
215+
216+
# don't pass the state when unnecessary
217+
# XXX: only apply in add_kernel_state! when needed?
218+
ModulePassManager() do pm
219+
dead_arg_elimination!(pm)
220+
run!(pm, mod)
221+
end
222+
end
223+
224+
return functions(mod)[entry_fn]
225+
end
197226

198227
add_lowering_passes!(@nospecialize(job::CompilerJob), pm::LLVM.PassManager) = return
199228

0 commit comments

Comments
 (0)