Skip to content

Commit 36488fb

Browse files
authored
Merge pull request #327 from JuliaGPU/tb/late_kernel_state
Perform all kernel-state transformations late
2 parents a393fff + d9b1b08 commit 36488fb

File tree

10 files changed

+65
-79
lines changed

10 files changed

+65
-79
lines changed

Manifest.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@ version = "1.4.1"
3939

4040
[[LLVM]]
4141
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
42-
git-tree-sha1 = "c9b86064be5ae0f63e50816a5a90b08c474507ae"
42+
git-tree-sha1 = "dd58421009014ff1ffacaa0db2a9a392114d75ee"
4343
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
44-
version = "4.9.1"
44+
version = "4.11.0"
4545

4646
[[LLVMExtra_jll]]
47-
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
48-
git-tree-sha1 = "5558ad3c8972d602451efe9d81c78ec14ef4f5ef"
47+
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
48+
git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576"
4949
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
50-
version = "0.0.14+2"
50+
version = "0.0.16+0"
5151

5252
[[LazyArtifacts]]
5353
deps = ["Artifacts", "Pkg"]

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
1414

1515
[compat]
1616
ExprTools = "0.1"
17-
LLVM = "4.8"
17+
LLVM = "4.11"
1818
TimerOutputs = "0.5"
1919
julia = "1.6"

src/driver.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,10 @@ const __llvm_initialized = Ref(false)
257257
# deferred code generation
258258
do_deferred_codegen = !only_entry && deferred_codegen &&
259259
haskey(functions(ir), "deferred_codegen")
260+
deferred_jobs = Dict{CompilerJob, String}(job => entry_fn)
260261
if do_deferred_codegen
261262
dyn_marker = functions(ir)["deferred_codegen"]
262263

263-
cache = Dict{CompilerJob, String}(job => entry_fn)
264-
265264
# iterative compilation (non-recursive)
266265
changed = true
267266
while changed
@@ -286,7 +285,7 @@ const __llvm_initialized = Ref(false)
286285
# compile and link
287286
for dyn_job in keys(worklist)
288287
# cached compilation
289-
dyn_entry_fn = get!(cache, dyn_job) do
288+
dyn_entry_fn = get!(deferred_jobs, dyn_job) do
290289
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; optimize=false,
291290
deferred_codegen=false, parent_job=job, ctx)
292291
dyn_entry_fn = LLVM.name(dyn_meta.entry)
@@ -317,6 +316,14 @@ const __llvm_initialized = Ref(false)
317316
end
318317

319318
@timeit_debug to "IR post-processing" begin
319+
# mark the kernel entry-point functions (optimization may need it)
320+
if job.source.kernel
321+
push!(metadata(ir)["julia.kernel"], MDNode([entry]; ctx))
322+
323+
# IDEA: save all jobs, not only kernels, and save other attributes
324+
# so that we can reconstruct the CompilerJob instead of setting it globally
325+
end
326+
320327
if optimize
321328
@timeit_debug to "optimization" begin
322329
optimize!(job, ir)
@@ -361,7 +368,18 @@ const __llvm_initialized = Ref(false)
361368
end
362369
end
363370

364-
entry = finish_ir!(job, ir, entry)
371+
# finish the module
372+
#
373+
# we want to finish the module after optimization, so we cannot do so during
374+
# deferred code generation. instead, process the deferred jobs here.
375+
if deferred_codegen
376+
entry = finish_ir!(job, ir, entry)
377+
378+
for (deferred_job, deferred_fn) in deferred_jobs
379+
deferred_job == job && continue
380+
finish_ir!(deferred_job, ir, functions(ir)[deferred_fn])
381+
end
382+
end
365383

366384
# replace non-entry function definitions with a declaration
367385
# NOTE: we can't do this before optimization, because the definitions of called

src/interface.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,7 @@ optimize_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module) = return
233233

234234
# finalization of the module, before deferred codegen and optimization
235235
function finish_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry::LLVM.Function)
236-
ctx = context(mod)
237-
entry_fn = LLVM.name(entry)
238-
239-
# add the kernel state, and lower calls to the `julia.gpu.state_getter` intrinsic.
240-
if job.source.kernel
241-
add_kernel_state!(job, mod, entry)
242-
end
243-
244-
return functions(mod)[entry_fn]
236+
return entry
245237
end
246238

247239
# final processing of the IR, right before validation and machine-code generation

src/irgen.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -404,21 +404,14 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
404404
# XXX: byval is not round-trippable on LLVM < 12 (see maleadt/LLVM.jl#186)
405405
# so we need to re-classify the Julia arguments.
406406
# remove this once we only support 1.7.
407-
has_kernel_state = kernel_state_type(job) !== Nothing
408-
orig_ft = if has_kernel_state
409-
# the kernel state has been added here already, so strip the first parameter
410-
LLVM.FunctionType(LLVM.return_type(ft), parameters(ft)[2:end]; vararg=isvararg(ft))
411-
else
412-
ft
413-
end
414-
args = classify_arguments(job, orig_ft)
407+
args = classify_arguments(job, ft)
415408
filter!(args) do arg
416409
arg.cc != GHOST
417410
end
418411
for arg in args
419412
if arg.cc == BITS_REF
420413
# NOTE: no offset needed, as the kernel state is now only added later, during optimization
421-
byval[arg.codegen.i+has_kernel_state] = true
414+
byval[arg.codegen.i] = true
422415
end
423416
end
424417
end
@@ -510,6 +503,7 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
510503
# NOTE: if we ever have legitimate uses of the old function, create a shim instead
511504
fn = LLVM.name(f)
512505
@assert isempty(uses(f))
506+
replace_metadata_uses!(f, new_f)
513507
unsafe_delete!(mod, f)
514508
LLVM.name!(new_f, fn)
515509

@@ -535,10 +529,9 @@ end
535529
# so that the `julia.gpu.state_getter` can be simplified to return an opaque pointer.
536530

537531
# add a state argument to every function in the module, starting from the kernel entry point
538-
function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
539-
entry::LLVM.Function)
532+
function add_kernel_state!(mod::LLVM.Module)
533+
job = current_job::CompilerJob
540534
ctx = context(mod)
541-
entry_fn = LLVM.name(entry)
542535

543536
# check if we even need a kernel state argument
544537
state = kernel_state_type(job)
@@ -552,12 +545,18 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
552545
# this is both for extern uses, and to make this transformation a two-step process.
553546
state_intr = kernel_state_intr(mod, T_state)
554547

548+
kernels = []
549+
kernels_md = metadata(mod)["julia.kernel"]
550+
for kernel_md in operands(kernels_md)
551+
push!(kernels, Value(operands(kernel_md)[1]; ctx))
552+
end
553+
555554
# determine which functions need a kernel state argument
556555
#
557556
# previously, we added the argument to every function and relied on unused arg elim to
558557
# clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
559558
# function pointers. such IR is hard to rewrite, so instead be more conservative.
560-
worklist = Set{LLVM.Function}([entry, state_intr])
559+
worklist = Set{LLVM.Function}([state_intr, kernels...])
561560
worklist_length = 0
562561
while worklist_length != length(worklist)
563562
# iteratively discover functions that use the intrinsic or any function calling it
@@ -669,6 +668,7 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
669668
error("old function still has uses")
670669
end
671670
end
671+
replace_metadata_uses!(f, workmap[f])
672672
unsafe_delete!(mod, f)
673673
end
674674

@@ -707,10 +707,12 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
707707
elseif val isa LLVM.CallBase
708708
# the function is being passed as an argument, which we'll just permit,
709709
# because we expect to have rewritten the call down the line separately.
710+
elseif val isa LLVM.StoreInst
711+
# the function is being stored, which again we'll permit like before.
710712
elseif val isa ConstantExpr
711713
rewrite_uses!(val)
712714
else
713-
error("Cannot rewrite unknown use of function: $val")
715+
error("Cannot rewrite $(typeof(val)) use of function: $val")
714716
end
715717
end
716718
end

src/optim.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
192192
if job.source.kernel
193193
# GC lowering is the last pass that may introduce calls to the runtime library,
194194
# and thus additional uses of the kernel state intrinsic.
195+
# TODO: now that all kernel state-related passes are being run here, merge some?
196+
add!(pm, ModulePass("AddKernelState", add_kernel_state!))
195197
add!(pm, FunctionPass("LowerKernelState", lower_kernel_state!))
196198
add!(pm, ModulePass("CleanupKernelState", cleanup_kernel_state!))
197199
end
@@ -253,8 +255,7 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
253255
ModulePassManager() do pm
254256
addTargetPasses!(pm, tm, triple)
255257

256-
# - remove unused kernel state arguments
257-
# - simplify function calls that don't use the returned value
258+
# simplify function calls that don't use the returned value
258259
dead_arg_elimination!(pm)
259260

260261
run!(pm, mod)
@@ -354,16 +355,8 @@ function lower_gc_frame!(fun::LLVM.Function)
354355

355356
# replace with PTX alloc_obj
356357
Builder(ctx) do builder
357-
# NOTE: this happens late during the pipeline, where we may have to
358-
# pass a kernel state arguments to the runtime function.
359-
state = if job.source.kernel
360-
kernel_state_type(job)
361-
else
362-
Nothing
363-
end
364-
365358
position!(builder, call)
366-
ptr = call!(builder, Runtime.get(:gc_pool_alloc), [sz]; state)
359+
ptr = call!(builder, Runtime.get(:gc_pool_alloc), [sz])
367360
replace_uses!(call, ptr)
368361
end
369362

src/ptx.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,17 @@ function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
185185
if job.source.kernel
186186
# work around bad byval codegen (JuliaGPU/GPUCompiler.jl#92)
187187
entry = lower_byval(job, mod, entry)
188-
# TODO: optimization passes to clean-up byval
188+
end
189+
190+
return entry
191+
end
189192

193+
function finish_ir!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
194+
mod::LLVM.Module, entry::LLVM.Function)
195+
ctx = context(mod)
196+
entry = invoke(finish_ir!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
197+
198+
if job.source.kernel
190199
# add metadata annotations for the assembler to the module
191200

192201
# property annotations

src/rtlib.jl

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ end
2828

2929
## higher-level functionality to work with runtime functions
3030

31-
function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[];
32-
state::Type=Nothing)
31+
function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[])
3332
bb = position(builder)
3433
f = LLVM.parent(bb)
3534
mod = LLVM.parent(f)
@@ -40,21 +39,10 @@ function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[
4039
f = functions(mod)[rt.llvm_name]
4140
ft = eltype(llvmtype(f))
4241
else
43-
ft = convert(LLVM.FunctionType, rt; ctx, state)
42+
ft = convert(LLVM.FunctionType, rt; ctx)
4443
f = LLVM.Function(mod, rt.llvm_name, ft)
4544
end
4645

47-
# we may be calling this function after kernel state lowering,
48-
# in which case we need to manually get and pass the state.
49-
args = Value[args...]
50-
if state !== Nothing
51-
T_state = convert(LLVMType, state; ctx)
52-
53-
state_intr = kernel_state_intr(mod, T_state)
54-
state_val = call!(builder, state_intr, Value[], "state")
55-
pushfirst!(args, state_val)
56-
end
57-
5846
# runtime functions are written in Julia, while we're calling from LLVM,
5947
# this often results in argument type mismatches. try to fix some here.
6048
for (i,arg) in enumerate(args)

src/runtime.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,13 @@ struct RuntimeMethodInstance
3434
end
3535

3636
function Base.convert(::Type{LLVM.FunctionType}, rt::RuntimeMethodInstance;
37-
ctx::LLVM.Context, state::Type=Nothing)
37+
ctx::LLVM.Context)
3838
types = if rt.llvm_types === nothing
3939
LLVMType[convert(LLVMType, typ; ctx, allow_boxed=true) for typ in rt.types]
4040
else
4141
rt.llvm_types(ctx)
4242
end
4343

44-
# if we're running post-optimization, prepend the kernel state to the argument list
45-
if state !== Nothing
46-
T_state = convert(LLVMType, state; ctx)
47-
pushfirst!(types, T_state)
48-
end
49-
5044
return_type = if rt.llvm_return_type === nothing
5145
convert(LLVMType, rt.return_type; ctx, allow_boxed=true)
5246
else

src/spirv.jl

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -232,26 +232,15 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F
232232
end
233233
else
234234
# XXX: byval is not round-trippable on LLVM < 12 (see maleadt/LLVM.jl#186)
235-
has_kernel_state = kernel_state_type(job) !== Nothing
236-
orig_ft = if has_kernel_state
237-
# the kernel state has been added here already, so strip the first parameter
238-
LLVM.FunctionType(LLVM.return_type(ft), parameters(ft)[2:end]; vararg=isvararg(ft))
239-
else
240-
ft
241-
end
242-
args = classify_arguments(job, orig_ft)
235+
args = classify_arguments(job, ft)
243236
filter!(args) do arg
244237
arg.cc != GHOST
245238
end
246239
for arg in args
247240
if arg.cc == BITS_REF
248-
# NOTE: +1 since this pass runs after introducing the kernel state
249-
byval[arg.codegen.i+has_kernel_state] = true
241+
byval[arg.codegen.i] = true
250242
end
251243
end
252-
if has_kernel_state
253-
byval[1] = true
254-
end
255244
end
256245

257246
# generate the wrapper function type & definition
@@ -317,6 +306,7 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F
317306
# NOTE: if we ever have legitimate uses of the old function, create a shim instead
318307
fn = LLVM.name(f)
319308
@assert isempty(uses(f))
309+
replace_metadata_uses!(f, new_f)
320310
unsafe_delete!(mod, f)
321311
LLVM.name!(new_f, fn)
322312

0 commit comments

Comments
 (0)