Skip to content

Commit 070406e

Browse files
authored
New interface: process_linked_module! (#727)
2 parents aab6333 + 1722ae3 commit 070406e

File tree

6 files changed

+112
-72
lines changed

6 files changed

+112
-72
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GPUCompiler"
22
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
3-
version = "1.7.0"
3+
version = "1.7.1"
44
authors = ["Tim Besard <[email protected]>"]
55

66
[deps]

src/driver.jl

Lines changed: 54 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,14 @@ const __llvm_initialized = Ref(false)
282282
erase!(call)
283283
end
284284
end
285+
286+
# minimal optimization to convert the inttoptr/call into a direct call
287+
@dispose pb=NewPMPassBuilder() begin
288+
add!(pb, NewPMFunctionPassManager()) do fpm
289+
add!(fpm, InstCombinePass())
290+
end
291+
run!(pb, ir, llvm_machine(job.config.target))
292+
end
285293
end
286294

287295
# all deferred compilations should have been resolved
@@ -312,10 +320,15 @@ const __llvm_initialized = Ref(false)
312320
end
313321

314322
@tracepoint "IR post-processing" begin
315-
# mark everything internal except for entrypoints and any exported
316-
# global variables. this makes sure that the optimizer can, e.g.,
317-
# rewrite function signatures.
323+
# mark the kernel entry-point functions (optimization may need it)
324+
if job.config.kernel
325+
mark_kernel!(entry)
326+
end
327+
318328
if job.config.toplevel
329+
# mark everything internal except for entrypoints and any exported
330+
# global variables. this makes sure that the optimizer can, e.g.,
331+
# rewrite function signatures.
319332
preserved_gvs = collect(values(jobs))
320333
for gvar in globals(ir)
321334
if linkage(gvar) == LLVM.API.LLVMExternalLinkage
@@ -331,64 +344,55 @@ const __llvm_initialized = Ref(false)
331344
run!(pm, ir)
332345
end
333346
end
334-
end
335-
336-
# mark the kernel entry-point functions (optimization may need it)
337-
if job.config.kernel
338-
push!(metadata(ir)["julia.kernel"], MDNode([entry]))
339-
340-
# IDEA: save all jobs, not only kernels, and save other attributes
341-
# so that we can reconstruct the CompileJob instead of setting it globally
342-
end
343347

344-
if job.config.toplevel && job.config.optimize
345-
@tracepoint "optimization" begin
346-
optimize!(job, ir; job.config.opt_level)
348+
finish_linked_module!(job, ir)
349+
350+
if job.config.optimize
351+
@tracepoint "optimization" begin
352+
optimize!(job, ir; job.config.opt_level)
353+
354+
# deferred codegen has some special optimization requirements,
355+
# which also need to happen _after_ regular optimization.
356+
# XXX: make these part of the optimizer pipeline?
357+
if has_deferred_jobs
358+
@dispose pb=NewPMPassBuilder() begin
359+
add!(pb, NewPMFunctionPassManager()) do fpm
360+
add!(fpm, InstCombinePass())
361+
end
362+
add!(pb, AlwaysInlinerPass())
363+
add!(pb, NewPMFunctionPassManager()) do fpm
364+
add!(fpm, SROAPass())
365+
add!(fpm, GVNPass())
366+
end
367+
add!(pb, MergeFunctionsPass())
368+
run!(pb, ir, llvm_machine(job.config.target))
369+
end
370+
end
371+
end
372+
end
347373

348-
# deferred codegen has some special optimization requirements,
349-
# which also need to happen _after_ regular optimization.
350-
# XXX: make these part of the optimizer pipeline?
351-
if has_deferred_jobs
374+
if job.config.cleanup
375+
@tracepoint "clean-up" begin
352376
@dispose pb=NewPMPassBuilder() begin
353-
add!(pb, NewPMFunctionPassManager()) do fpm
354-
add!(fpm, InstCombinePass())
355-
end
356-
add!(pb, AlwaysInlinerPass())
357-
add!(pb, NewPMFunctionPassManager()) do fpm
358-
add!(fpm, SROAPass())
359-
add!(fpm, GVNPass())
360-
end
361-
add!(pb, MergeFunctionsPass())
377+
add!(pb, RecomputeGlobalsAAPass())
378+
add!(pb, GlobalOptPass())
379+
add!(pb, GlobalDCEPass())
380+
add!(pb, StripDeadPrototypesPass())
381+
add!(pb, ConstantMergePass())
362382
run!(pb, ir, llvm_machine(job.config.target))
363383
end
364384
end
365385
end
366386

367387
# optimization may have replaced functions, so look the entry point up again
368388
entry = functions(ir)[entry_fn]
369-
end
370389

371-
if job.config.toplevel && job.config.cleanup
372-
@tracepoint "clean-up" begin
373-
@dispose pb=NewPMPassBuilder() begin
374-
add!(pb, RecomputeGlobalsAAPass())
375-
add!(pb, GlobalOptPass())
376-
add!(pb, GlobalDCEPass())
377-
add!(pb, StripDeadPrototypesPass())
378-
add!(pb, ConstantMergePass())
379-
run!(pb, ir, llvm_machine(job.config.target))
380-
end
381-
end
382-
end
383-
384-
# finish the module
385-
#
386-
# we want to finish the module after optimization, so we cannot do so
387-
# during deferred code generation. instead, process the deferred jobs
388-
# here.
389-
if job.config.toplevel
390+
# finish the module
391+
#
392+
# we want to finish the module after optimization, so we cannot do so
393+
# during deferred code generation. instead, process the deferred jobs
394+
# here.
390395
entry = finish_ir!(job, ir, entry)
391-
392396
for (job′, fn′) in jobs
393397
job′ == job && continue
394398
finish_ir!(job′, ir, functions(ir)[fn′])
@@ -409,7 +413,7 @@ const __llvm_initialized = Ref(false)
409413
end
410414

411415
if job.config.toplevel && job.config.validate
412-
@tracepoint "Validation" begin
416+
@tracepoint "validation" begin
413417
check_ir(job, ir)
414418
end
415419
end

src/interface.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,9 @@ link_libraries!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
354354
finish_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry::LLVM.Function) =
355355
entry
356356

357+
# finalization of linked modules, after deferred codegen but before optimization
358+
finish_linked_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module) = return
359+
357360
# post-Julia optimization processing of the module
358361
optimize_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module) = return
359362

src/irgen.jl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -526,18 +526,12 @@ function add_kernel_state!(mod::LLVM.Module)
526526
state_intr = kernel_state_intr(mod, T_state)
527527
state_intr_ft = LLVM.FunctionType(T_state)
528528

529-
kernels = []
530-
kernels_md = metadata(mod)["julia.kernel"]
531-
for kernel_md in operands(kernels_md)
532-
push!(kernels, Value(operands(kernel_md)[1]))
533-
end
534-
535529
# determine which functions need a kernel state argument
536530
#
537531
# previously, we added the argument to every function and relied on unused arg elim to
538532
# clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
539533
# function pointers. such IR is hard to rewrite, so instead be more conservative.
540-
worklist = Set{LLVM.Function}([state_intr, kernels...])
534+
worklist = Set{LLVM.Function}([state_intr, kernels(mod)...])
541535
worklist_length = 0
542536
while worklist_length != length(worklist)
543537
# iteratively discover functions that use the intrinsic or any function calling it
@@ -941,12 +935,24 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
941935
while worklist_length != length(worklist)
942936
# iteratively discover functions that use an intrinsic or any function calling it
943937
worklist_length = length(worklist)
944-
additions = LLVM.Function[]
945-
for f in worklist, use in uses(f)
946-
inst = user(use)::Instruction
947-
bb = LLVM.parent(inst)
948-
new_f = LLVM.parent(bb)
949-
in(new_f, worklist) || push!(additions, new_f)
938+
additions = Set{LLVM.Function}()
939+
function scan_uses(val)
940+
for use in uses(val)
941+
candidate = user(use)
942+
if isa(candidate, Instruction)
943+
bb = LLVM.parent(candidate)
944+
new_f = LLVM.parent(bb)
945+
in(new_f, worklist) || push!(additions, new_f)
946+
elseif isa(candidate, ConstantExpr)
947+
@safe_info candidate
948+
scan_uses(candidate)
949+
else
950+
error("Don't know how to check uses of $candidate. Please file an issue.")
951+
end
952+
end
953+
end
954+
for f in worklist
955+
scan_uses(f)
950956
end
951957
for f in additions
952958
push!(worklist, f)
@@ -1054,6 +1060,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
10541060
for (f, new_f) in workmap
10551061
rewrite_uses!(f, new_f)
10561062
@assert isempty(uses(f))
1063+
replace_metadata_uses!(f, new_f)
10571064
erase!(f)
10581065
end
10591066

src/metal.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,11 @@ runtime_slug(job::CompilerJob{MetalCompilerTarget}) = "metal-macos$(job.config.t
4747
isintrinsic(@nospecialize(job::CompilerJob{MetalCompilerTarget}), fn::String) =
4848
return startswith(fn, "air.")
4949

50-
function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module, entry::LLVM.Function)
51-
entry_fn = LLVM.name(entry)
52-
53-
# update calling conventions
54-
if job.config.kernel
55-
entry = pass_by_reference!(job, mod, entry)
56-
entry = add_input_arguments!(job, mod, entry, kernel_intrinsics)
50+
function finish_linked_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module)
51+
for f in kernels(mod)
52+
# update calling conventions
53+
f = pass_by_reference!(job, mod, f)
54+
f = add_input_arguments!(job, mod, f, kernel_intrinsics)
5755
end
5856

5957
# emit the AIR and Metal version numbers as constants in the module. this makes it
@@ -83,7 +81,7 @@ function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mo
8381
run!(pb, mod)
8482
end
8583

86-
return functions(mod)[entry_fn]
84+
return
8785
end
8886

8987
function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module)
@@ -497,6 +495,7 @@ function pass_by_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f
497495
# NOTE: if we ever have legitimate uses of the old function, create a shim instead
498496
fn = LLVM.name(f)
499497
@assert isempty(uses(f))
498+
replace_metadata_uses!(f, new_f)
500499
erase!(f)
501500
LLVM.name!(new_f, fn)
502501

src/utils.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,30 @@ function prune_constexpr_uses!(root::LLVM.Value)
155155
end
156156
end
157157
end
158+
159+
160+
## kernel metadata handling
161+
162+
# kernels are encoded in the IR using the julia.kernel metadata.
163+
164+
# IDEA: don't only mark kernels, but all jobs, and save all attributes of the CompilerJob
165+
# so that we can reconstruct the CompilerJob instead of setting it globally
166+
167+
# mark a function as kernel
168+
function mark_kernel!(f::LLVM.Function)
169+
mod = LLVM.parent(f)
170+
push!(metadata(mod)["julia.kernel"], MDNode([f]))
171+
return f
172+
end
173+
174+
# collect all kernel functions recorded in the module's julia.kernel metadata
175+
function kernels(mod::LLVM.Module)
176+
vals = LLVM.Function[]
177+
if haskey(metadata(mod), "julia.kernel")
178+
kernels_md = metadata(mod)["julia.kernel"]
179+
for kernel_md in operands(kernels_md)
180+
push!(vals, LLVM.Value(operands(kernel_md)[1]))
181+
end
182+
end
183+
return vals
184+
end

0 commit comments

Comments
 (0)