Skip to content

Commit 1f87ce6

Browse files
committed
Rewrite byval passes using clone utils.
Also have them look for the byval attribute instead of processing arguments from scratch again.
1 parent de62426 commit 1f87ce6

File tree

5 files changed

+120
-141
lines changed

5 files changed

+120
-141
lines changed

Manifest.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,17 @@ version = "1.3.0"
3939

4040
[[LLVM]]
4141
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
42-
git-tree-sha1 = "d6041ad706cf458b2c9f3e501152488a26451e9c"
42+
git-tree-sha1 = "effe3552dba16b1e9c3cc9beac454a3566f09637"
43+
repo-rev = "b4dfdfcf86dde4563c87e90d130c78dbbe8550f8"
44+
repo-url = "https://github.com/maleadt/LLVM.jl.git"
4345
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
4446
version = "4.2.0"
4547

4648
[[LLVMExtra_jll]]
4749
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
48-
git-tree-sha1 = "a9b1130c4728b0e462a1c28772954650039eb847"
50+
git-tree-sha1 = "873e7962f14f6bdd8a0e10552d964ec0a7c69f3b"
4951
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
50-
version = "0.0.7+0"
52+
version = "0.0.9+0"
5153

5254
[[LibCURL]]
5355
deps = ["LibCURL_jll", "MozillaCACerts_jll"]

src/gcn.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ function process_module!(job::CompilerJob{GCNCompilerTarget}, mod::LLVM.Module)
4444
end
4545

4646
function process_entry!(job::CompilerJob{GCNCompilerTarget}, mod::LLVM.Module, entry::LLVM.Function)
47+
invoke(process_entry!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
48+
4749
if job.source.kernel
4850
entry = lower_byval(job, mod, entry)
4951

src/irgen.jl

Lines changed: 52 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -382,114 +382,78 @@ end
382382
# some back-ends don't support byval, or support it badly
383383
# https://reviews.llvm.org/D79744
384384

385-
# generate a kernel wrapper to fix & improve argument passing
386-
function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry_f::LLVM.Function)
385+
# modify the kernel function to fix & improve argument passing
386+
function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
387387
ctx = context(mod)
388-
entry_ft = eltype(llvmtype(entry_f)::LLVM.PointerType)::LLVM.FunctionType
389-
@compiler_assert return_type(entry_ft) == LLVM.VoidType(ctx) job
390-
391-
args = classify_arguments(job, entry_f)
392-
filter!(args) do arg
393-
arg.cc != GHOST
388+
ft = eltype(llvmtype(f)::LLVM.PointerType)::LLVM.FunctionType
389+
@compiler_assert return_type(ft) == LLVM.VoidType(ctx) job
390+
391+
# find the byval parameters
392+
byval = BitVector(undef, length(parameters(ft)))
393+
for i in 1:length(byval)
394+
attrs = collect(parameter_attributes(f, i))
395+
byval[i] = any(attrs) do attr
396+
kind(attr) == kind(EnumAttribute("byval", 0; ctx))
397+
end
394398
end
395399

396-
# generate the wrapper function type & definition
397-
wrapper_types = LLVM.LLVMType[]
398-
for arg in args
399-
typ = if arg.cc == BITS_REF
400-
eltype(arg.codegen.typ)
400+
# generate the new function type & definition
401+
new_types = LLVM.LLVMType[]
402+
for (i, param) in enumerate(parameters(ft))
403+
if byval[i]
404+
push!(new_types, eltype(param::LLVM.PointerType))
401405
else
402-
convert(LLVMType, arg.typ; ctx)
406+
push!(new_types, param)
403407
end
404-
push!(wrapper_types, typ)
405408
end
406-
wrapper_fn = LLVM.name(entry_f)
407-
LLVM.name!(entry_f, wrapper_fn * ".inner")
408-
wrapper_ft = LLVM.FunctionType(LLVM.VoidType(ctx), wrapper_types)
409-
wrapper_f = LLVM.Function(mod, wrapper_fn, wrapper_ft)
409+
new_ft = LLVM.FunctionType(return_type(ft), new_types)
410+
new_f = LLVM.Function(mod, "", new_ft)
411+
linkage!(new_f, linkage(f))
410412

411413
# emit IR performing the "conversions"
412-
let builder = Builder(ctx)
413-
entry = BasicBlock(wrapper_f, "entry"; ctx)
414+
new_args = LLVM.Value[]
415+
Builder(ctx) do builder
416+
entry = BasicBlock(new_f, "entry"; ctx)
414417
position!(builder, entry)
415418

416-
wrapper_args = Vector{LLVM.Value}()
417-
418419
# perform argument conversions
419-
for arg in args
420-
if arg.cc == BITS_REF
420+
for (i, param) in enumerate(parameters(ft))
421+
if byval[i]
421422
# copy the argument value to a stack slot, and reference it.
422-
ptr = alloca!(builder, eltype(arg.codegen.typ))
423-
if LLVM.addrspace(arg.codegen.typ) != 0
424-
ptr = addrspacecast!(builder, ptr, arg.codegen.typ)
423+
ptr = alloca!(builder, eltype(param))
424+
if LLVM.addrspace(param) != 0
425+
ptr = addrspacecast!(builder, ptr, param)
425426
end
426-
store!(builder, parameters(wrapper_f)[arg.codegen.i], ptr)
427-
push!(wrapper_args, ptr)
427+
store!(builder, parameters(new_f)[i], ptr)
428+
push!(new_args, ptr)
428429
else
429-
push!(wrapper_args, parameters(wrapper_f)[arg.codegen.i])
430-
for attr in collect(parameter_attributes(entry_f, arg.codegen.i))
431-
push!(parameter_attributes(wrapper_f, arg.codegen.i), attr)
430+
push!(new_args, parameters(new_f)[i])
431+
for attr in collect(parameter_attributes(f, i))
432+
push!(parameter_attributes(new_f, i), attr)
432433
end
433434
end
434435
end
435436

436-
call!(builder, entry_f, wrapper_args)
437-
438-
ret!(builder)
439-
440-
dispose(builder)
441-
end
442-
443-
# early-inline the original entry function into the wrapper
444-
push!(function_attributes(entry_f), EnumAttribute("alwaysinline", 0; ctx))
445-
linkage!(entry_f, LLVM.API.LLVMInternalLinkage)
437+
# inline the old IR
438+
value_map = Dict{LLVM.Value, LLVM.Value}(
439+
param => new_args[i] for (i,param) in enumerate(parameters(f))
440+
)
441+
clone_into!(new_f, f; value_map,
442+
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
443+
# NOTE: we need global changes because LLVM 12 wants to clone debug metadata
446444

447-
# copy debug info
448-
sp = LLVM.get_subprogram(entry_f)
449-
if sp !== nothing
450-
LLVM.set_subprogram!(wrapper_f, sp)
445+
# fall through
446+
br!(builder, collect(blocks(new_f))[2])
451447
end
452448

453-
fixup_metadata!(entry_f)
454-
ModulePassManager() do pm
455-
always_inliner!(pm)
456-
run!(pm, mod)
457-
end
449+
# remove the old function
450+
# NOTE: if we ever have legitimate uses of the old function, create a shim instead
451+
fn = LLVM.name(f)
452+
@assert isempty(uses(f))
453+
# XXX: there may still be metadata using this function. RAUW updates those,
454+
# but asserts on a debug build due to the updated function type.
455+
unsafe_delete!(mod, f)
456+
LLVM.name!(new_f, fn)
458457

459-
return wrapper_f
460-
end
461-
462-
# HACK: get rid of invariant.load and const TBAA metadata on loads from pointer args,
463-
# since storing to a stack slot violates the semantics of those attributes.
464-
# TODO: can we emit a wrapper that doesn't violate Julia's metadata?
465-
function fixup_metadata!(f::LLVM.Function)
466-
for param in parameters(f)
467-
if isa(llvmtype(param), LLVM.PointerType)
468-
# collect all uses of the pointer
469-
worklist = Vector{LLVM.Instruction}(user.(collect(uses(param))))
470-
while !isempty(worklist)
471-
value = popfirst!(worklist)
472-
473-
# remove the invariant.load attribute
474-
md = metadata(value)
475-
if haskey(md, LLVM.MD_invariant_load)
476-
delete!(md, LLVM.MD_invariant_load)
477-
end
478-
if haskey(md, LLVM.MD_tbaa)
479-
delete!(md, LLVM.MD_tbaa)
480-
end
481-
482-
# recurse on the output of some instructions
483-
if isa(value, LLVM.BitCastInst) ||
484-
isa(value, LLVM.GetElementPtrInst) ||
485-
isa(value, LLVM.AddrSpaceCastInst)
486-
append!(worklist, user.(collect(uses(value))))
487-
end
488-
489-
# IMPORTANT NOTE: if we ever want to inline functions at the LLVM level,
490-
# we need to recurse into call instructions here, and strip metadata from
491-
# called functions (see CUDAnative.jl#238).
492-
end
493-
end
494-
end
458+
return new_f
495459
end

src/ptx.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,9 @@ end
114114

115115
function process_entry!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
116116
mod::LLVM.Module, entry::LLVM.Function)
117-
ctx = context(mod)
117+
invoke(process_entry!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
118118

119+
ctx = context(mod)
119120
if job.source.kernel
120121
# work around bad byval codegen (JuliaGPU/GPUCompiler.jl#92)
121122
entry = lower_byval(job, mod, entry)

src/spirv.jl

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ function process_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module
4040
end
4141

4242
function process_entry!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, entry::LLVM.Function)
43+
invoke(process_entry!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
44+
4345
if job.source.kernel
4446
# HACK: Intel's compute runtime doesn't properly support SPIR-V's byval attribute.
4547
# they do support struct byval, for OpenCL, so wrap byval parameters in a struct.
@@ -193,75 +195,83 @@ function rm_freeze!(mod::LLVM.Module)
193195
end
194196

195197
# wrap byval pointers in a single-value struct
196-
function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry_f::LLVM.Function)
198+
function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
197199
ctx = context(mod)
198-
entry_ft = eltype(llvmtype(entry_f)::LLVM.PointerType)::LLVM.FunctionType
199-
@compiler_assert return_type(entry_ft) == LLVM.VoidType(ctx) job
200-
201-
args = classify_arguments(job, entry_f)
202-
filter!(args) do arg
203-
arg.cc != GHOST
200+
ft = eltype(llvmtype(f)::LLVM.PointerType)::LLVM.FunctionType
201+
@compiler_assert return_type(ft) == LLVM.VoidType(ctx) job
202+
203+
# find the byval parameters
204+
byval = BitVector(undef, length(parameters(ft)))
205+
for i in 1:length(byval)
206+
attrs = collect(parameter_attributes(f, i))
207+
byval[i] = any(attrs) do attr
208+
kind(attr) == kind(EnumAttribute("byval", 0; ctx))
209+
end
204210
end
205211

206212
# generate the wrapper function type & definition
207-
wrapper_types = LLVM.LLVMType[]
208-
for arg in args
209-
typ = if arg.cc == BITS_REF
210-
st = LLVM.StructType([eltype(arg.codegen.typ)]; ctx)
211-
LLVM.PointerType(st, addrspace(arg.codegen.typ))
213+
new_types = LLVM.LLVMType[]
214+
for (i, param) in enumerate(parameters(ft))
215+
typ = if byval[i]
216+
st = LLVM.StructType([eltype(param)]; ctx)
217+
LLVM.PointerType(st, addrspace(param))
212218
else
213-
convert(LLVMType, arg.typ; ctx)
219+
param
214220
end
215-
push!(wrapper_types, typ)
221+
push!(new_types, typ)
216222
end
217-
wrapper_fn = LLVM.name(entry_f)
218-
LLVM.name!(entry_f, wrapper_fn * ".inner")
219-
wrapper_ft = LLVM.FunctionType(LLVM.VoidType(ctx), wrapper_types)
220-
wrapper_f = LLVM.Function(mod, wrapper_fn, wrapper_ft)
223+
new_ft = LLVM.FunctionType(LLVM.VoidType(ctx), new_types)
224+
new_f = LLVM.Function(mod, "", new_ft)
225+
linkage!(new_f, linkage(f))
221226

222227
# emit IR performing the "conversions"
223-
let builder = Builder(ctx)
224-
entry = BasicBlock(wrapper_f, "entry"; ctx)
228+
new_args = Vector{LLVM.Value}()
229+
Builder(ctx) do builder
230+
entry = BasicBlock(new_f, "entry"; ctx)
225231
position!(builder, entry)
226232

227-
wrapper_args = Vector{LLVM.Value}()
228-
229233
# perform argument conversions
230-
for arg in args
231-
param = parameters(wrapper_f)[arg.codegen.i]
232-
attrs = parameter_attributes(wrapper_f, arg.codegen.i)
233-
if arg.cc == BITS_REF
234+
for (i, param) in enumerate(parameters(new_f))
235+
if byval[i]
236+
ptr = struct_gep!(builder, param, 0)
237+
push!(new_args, ptr)
238+
else
239+
push!(new_args, param)
240+
end
241+
end
242+
243+
# inline the old IR
244+
value_map = Dict{LLVM.Value, LLVM.Value}(
245+
param => new_args[i] for (i,param) in enumerate(parameters(f))
246+
)
247+
clone_into!(new_f, f; value_map,
248+
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
249+
# NOTE: we need global changes because LLVM 12 wants to clone debug metadata
250+
251+
# apply byval attributes again (`clone_into!` didn't due to the type mismatch)
252+
for i in 1:length(byval)
253+
attrs = parameter_attributes(new_f, i)
254+
if byval[i]
234255
if LLVM.version() >= v"12"
235-
push!(attrs, TypeAttribute("byval", eltype(wrapper_types[arg.codegen.i]); ctx))
256+
push!(attrs, TypeAttribute("byval", eltype(new_types[i]); ctx))
236257
else
237258
push!(attrs, EnumAttribute("byval", 0; ctx))
238259
end
239-
ptr = struct_gep!(builder, param, 0)
240-
push!(wrapper_args, ptr)
241-
else
242-
push!(wrapper_args, param)
243-
for attr in collect(attrs)
244-
push!(parameter_attributes(wrapper_f, arg.codegen.i), attr)
245-
end
246260
end
247261
end
248262

249-
call!(builder, entry_f, wrapper_args)
250-
251-
ret!(builder)
252-
253-
dispose(builder)
263+
# fall through
264+
br!(builder, collect(blocks(new_f))[2])
254265
end
255266

256-
# early-inline the original entry function into the wrapper
257-
delete!(function_attributes(entry_f), EnumAttribute("noinline", 0; ctx))
258-
push!(function_attributes(entry_f), EnumAttribute("alwaysinline", 0; ctx))
259-
linkage!(entry_f, LLVM.API.LLVMInternalLinkage)
260-
261-
ModulePassManager() do pm
262-
always_inliner!(pm)
263-
run!(pm, mod)
264-
end
267+
# remove the old function
268+
# NOTE: if we ever have legitimate uses of the old function, create a shim instead
269+
fn = LLVM.name(f)
270+
@assert isempty(uses(f))
271+
# XXX: there may still be metadata using this function. RAUW updates those,
272+
# but asserts on a debug build due to the updated function type.
273+
unsafe_delete!(mod, f)
274+
LLVM.name!(new_f, fn)
265275

266-
return wrapper_f
276+
return new_f
267277
end

0 commit comments

Comments
 (0)