Skip to content

Commit 7873cc6

Browse files
authored
Merge pull request #234 from JuliaGPU/tb/byval_clone
Rewrite byval lowering using clone utils.
2 parents cc77b90 + 6d62b98 commit 7873cc6

File tree

6 files changed

+178
-135
lines changed

6 files changed

+178
-135
lines changed

Manifest.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,17 @@ version = "1.3.0"
3939

4040
[[LLVM]]
4141
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
42-
git-tree-sha1 = "d6041ad706cf458b2c9f3e501152488a26451e9c"
42+
git-tree-sha1 = "97b3606f7e23cc7afd594d8a094f1f087f1d6511"
43+
repo-rev = "2c5db93"
44+
repo-url = "https://github.com/maleadt/LLVM.jl.git"
4345
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
4446
version = "4.2.0"
4547

4648
[[LLVMExtra_jll]]
4749
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
48-
git-tree-sha1 = "a9b1130c4728b0e462a1c28772954650039eb847"
50+
git-tree-sha1 = "873e7962f14f6bdd8a0e10552d964ec0a7c69f3b"
4951
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
50-
version = "0.0.7+0"
52+
version = "0.0.9+0"
5153

5254
[[LibCURL]]
5355
deps = ["LibCURL_jll", "MozillaCACerts_jll"]

src/gcn.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ function process_module!(job::CompilerJob{GCNCompilerTarget}, mod::LLVM.Module)
4444
end
4545

4646
function process_entry!(job::CompilerJob{GCNCompilerTarget}, mod::LLVM.Module, entry::LLVM.Function)
47+
invoke(process_entry!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
48+
4749
if job.source.kernel
4850
entry = lower_byval(job, mod, entry)
4951

@@ -109,7 +111,7 @@ function lower_throw_extra!(mod::LLVM.Module)
109111
end
110112

111113
# remove the call
112-
call_args = collect(operands(call))[1:end-1] # last arg is function itself
114+
call_args = operands(call)[1:end-1] # last arg is function itself
113115
unsafe_delete!(LLVM.parent(call), call)
114116

115117
# HACK: kill the exceptions' unused arguments

src/irgen.jl

Lines changed: 95 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ function irgen(@nospecialize(job::CompilerJob), method_instance::Core.MethodInst
1515
if Sys.iswindows()
1616
personality!(llvmf, nothing)
1717
end
18+
19+
# remove the non-specialized jfptr functions
20+
if startswith(LLVM.name(llvmf), "jfptr_")
21+
unsafe_delete!(mod, llvmf)
22+
end
1823
end
1924

2025
# remove the exception-handling personality function
@@ -68,6 +73,9 @@ function irgen(@nospecialize(job::CompilerJob), method_instance::Core.MethodInst
6873
end
6974
internalize!(pm, exports)
7075

76+
# inline llvmcall bodies
77+
always_inliner!(pm)
78+
7179
can_throw(job) || add!(pm, ModulePass("LowerThrow", lower_throw!))
7280

7381
add_lowering_passes!(job, pm)
@@ -199,7 +207,7 @@ function lower_throw!(mod::LLVM.Module)
199207
end
200208

201209
# remove the call
202-
call_args = collect(operands(call))[1:end-1] # last arg is function itself
210+
call_args = operands(call)[1:end-1] # last arg is function itself
203211
unsafe_delete!(LLVM.parent(call), call)
204212

205213
# HACK: kill the exceptions' unused arguments
@@ -377,90 +385,41 @@ end
377385
# some back-ends don't support byval, or support it badly
378386
# https://reviews.llvm.org/D79744
379387

380-
# generate a kernel wrapper to fix & improve argument passing
381-
function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry_f::LLVM.Function)
388+
# modify the kernel function to fix & improve argument passing
389+
function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
382390
ctx = context(mod)
383-
entry_ft = eltype(llvmtype(entry_f)::LLVM.PointerType)::LLVM.FunctionType
384-
@compiler_assert return_type(entry_ft) == LLVM.VoidType(ctx) job
385-
386-
args = classify_arguments(job, entry_f)
387-
filter!(args) do arg
388-
arg.cc != GHOST
389-
end
390-
391-
# generate the wrapper function type & definition
392-
wrapper_types = LLVM.LLVMType[]
393-
for arg in args
394-
typ = if arg.cc == BITS_REF
395-
eltype(arg.codegen.typ)
396-
else
397-
convert(LLVMType, arg.typ; ctx)
391+
ft = eltype(llvmtype(f)::LLVM.PointerType)::LLVM.FunctionType
392+
@compiler_assert return_type(ft) == LLVM.VoidType(ctx) job
393+
394+
# find the byval parameters
395+
byval = BitVector(undef, length(parameters(ft)))
396+
if LLVM.version() >= v"12"
397+
for i in 1:length(byval)
398+
attrs = collect(parameter_attributes(f, i))
399+
byval[i] = any(attrs) do attr
400+
kind(attr) == kind(EnumAttribute("byval", 0; ctx))
401+
end
402+
end
403+
else
404+
# XXX: byval is not round-trippable on LLVM < 12 (see maleadt/LLVM.jl#186)
405+
args = classify_arguments(job, f)
406+
filter!(args) do arg
407+
arg.cc != GHOST
398408
end
399-
push!(wrapper_types, typ)
400-
end
401-
wrapper_fn = LLVM.name(entry_f)
402-
LLVM.name!(entry_f, wrapper_fn * ".inner")
403-
wrapper_ft = LLVM.FunctionType(LLVM.VoidType(ctx), wrapper_types)
404-
wrapper_f = LLVM.Function(mod, wrapper_fn, wrapper_ft)
405-
406-
# emit IR performing the "conversions"
407-
let builder = Builder(ctx)
408-
entry = BasicBlock(wrapper_f, "entry"; ctx)
409-
position!(builder, entry)
410-
411-
wrapper_args = Vector{LLVM.Value}()
412-
413-
# perform argument conversions
414409
for arg in args
415410
if arg.cc == BITS_REF
416-
# copy the argument value to a stack slot, and reference it.
417-
ptr = alloca!(builder, eltype(arg.codegen.typ))
418-
if LLVM.addrspace(arg.codegen.typ) != 0
419-
ptr = addrspacecast!(builder, ptr, arg.codegen.typ)
420-
end
421-
store!(builder, parameters(wrapper_f)[arg.codegen.i], ptr)
422-
push!(wrapper_args, ptr)
423-
else
424-
push!(wrapper_args, parameters(wrapper_f)[arg.codegen.i])
425-
for attr in collect(parameter_attributes(entry_f, arg.codegen.i))
426-
push!(parameter_attributes(wrapper_f, arg.codegen.i), attr)
427-
end
411+
byval[arg.codegen.i] = true
428412
end
429413
end
430-
431-
call!(builder, entry_f, wrapper_args)
432-
433-
ret!(builder)
434-
435-
dispose(builder)
436414
end
437415

438-
# early-inline the original entry function into the wrapper
439-
push!(function_attributes(entry_f), EnumAttribute("alwaysinline", 0; ctx))
440-
linkage!(entry_f, LLVM.API.LLVMInternalLinkage)
441-
442-
# copy debug info
443-
sp = LLVM.get_subprogram(entry_f)
444-
if sp !== nothing
445-
LLVM.set_subprogram!(wrapper_f, sp)
446-
end
447-
448-
fixup_metadata!(entry_f)
449-
ModulePassManager() do pm
450-
always_inliner!(pm)
451-
run!(pm, mod)
452-
end
453-
454-
return wrapper_f
455-
end
456-
457-
# HACK: get rid of invariant.load and const TBAA metadata on loads from pointer args,
458-
# since storing to a stack slot violates the semantics of those attributes.
459-
# TODO: can we emit a wrapper that doesn't violate Julia's metadata?
460-
function fixup_metadata!(f::LLVM.Function)
461-
for param in parameters(f)
462-
if isa(llvmtype(param), LLVM.PointerType)
463-
# collect all uses of the pointer
416+
# fixup metadata
417+
#
418+
# Julia emits invariant.load and const TBAA metadta on loads from pointer args,
419+
# which is invalid now that we have materialized the byval.
420+
for (i, param) in enumerate(parameters(f))
421+
if byval[i]
422+
# collect all uses of the argument
464423
worklist = Vector{LLVM.Instruction}(user.(collect(uses(param))))
465424
while !isempty(worklist)
466425
value = popfirst!(worklist)
@@ -480,11 +439,67 @@ function fixup_metadata!(f::LLVM.Function)
480439
isa(value, LLVM.AddrSpaceCastInst)
481440
append!(worklist, user.(collect(uses(value))))
482441
end
442+
end
443+
end
444+
end
445+
446+
# generate the new function type & definition
447+
new_types = LLVM.LLVMType[]
448+
for (i, param) in enumerate(parameters(ft))
449+
if byval[i]
450+
push!(new_types, eltype(param::LLVM.PointerType))
451+
else
452+
push!(new_types, param)
453+
end
454+
end
455+
new_ft = LLVM.FunctionType(return_type(ft), new_types)
456+
new_f = LLVM.Function(mod, "", new_ft)
457+
linkage!(new_f, linkage(f))
458+
459+
# emit IR performing the "conversions"
460+
new_args = LLVM.Value[]
461+
Builder(ctx) do builder
462+
entry = BasicBlock(new_f, "entry"; ctx)
463+
position!(builder, entry)
483464

484-
# IMPORTANT NOTE: if we ever want to inline functions at the LLVM level,
485-
# we need to recurse into call instructions here, and strip metadata from
486-
# called functions (see CUDAnative.jl#238).
465+
# perform argument conversions
466+
for (i, param) in enumerate(parameters(ft))
467+
if byval[i]
468+
# copy the argument value to a stack slot, and reference it.
469+
ptr = alloca!(builder, eltype(param))
470+
if LLVM.addrspace(param) != 0
471+
ptr = addrspacecast!(builder, ptr, param)
472+
end
473+
store!(builder, parameters(new_f)[i], ptr)
474+
push!(new_args, ptr)
475+
else
476+
push!(new_args, parameters(new_f)[i])
477+
for attr in collect(parameter_attributes(f, i))
478+
push!(parameter_attributes(new_f, i), attr)
479+
end
487480
end
488481
end
482+
483+
# inline the old IR
484+
value_map = Dict{LLVM.Value, LLVM.Value}(
485+
param => new_args[i] for (i,param) in enumerate(parameters(f))
486+
)
487+
clone_into!(new_f, f; value_map,
488+
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
489+
# NOTE: we need global changes because LLVM 12 wants to clone debug metadata
490+
491+
# fall through
492+
br!(builder, blocks(new_f)[2])
489493
end
494+
495+
# remove the old function
496+
# NOTE: if we ever have legitimate uses of the old function, create a shim instead
497+
fn = LLVM.name(f)
498+
@assert isempty(uses(f))
499+
# XXX: there may still be metadata using this function. RAUW updates those,
500+
# but asserts on a debug build due to the updated function type.
501+
unsafe_delete!(mod, f)
502+
LLVM.name!(new_f, fn)
503+
504+
return new_f
490505
end

src/optim.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ function lower_gc_frame!(fun::LLVM.Function)
242242
call = user(use)::LLVM.CallInst
243243

244244
# decode the call
245-
ops = collect(operands(call))
245+
ops = operands(call)
246246
sz = ops[2]
247247

248248
# replace with PTX alloc_obj

src/ptx.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,9 @@ end
114114

115115
function process_entry!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
116116
mod::LLVM.Module, entry::LLVM.Function)
117-
ctx = context(mod)
117+
invoke(process_entry!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
118118

119+
ctx = context(mod)
119120
if job.source.kernel
120121
# work around bad byval codegen (JuliaGPU/GPUCompiler.jl#92)
121122
entry = lower_byval(job, mod, entry)

0 commit comments

Comments
 (0)