Skip to content

Commit 40f4292

Browse files
authored
Merge pull request #322 from JuliaGPU/tb/kernel_state
Kernel state rewriting: support more IR patterns.
2 parents 0546dd9 + 0a22a88 commit 40f4292

File tree

2 files changed

+91
-21
lines changed

2 files changed

+91
-21
lines changed

src/irgen.jl

Lines changed: 89 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -563,11 +563,22 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
563563
# iteratively discover functions that use the intrinsic or any function calling it
564564
worklist_length = length(worklist)
565565
additions = LLVM.Function[]
566+
function check_user(val)
567+
if val isa Instruction
568+
bb = LLVM.parent(val)
569+
new_f = LLVM.parent(bb)
570+
in(new_f, worklist) || push!(additions, new_f)
571+
elseif val isa ConstantExpr
572+
# constant expressions don't have a parent; we need to look up their uses
573+
for use in uses(val)
574+
check_user(user(use))
575+
end
576+
else
577+
error("Don't know how to check uses of $val. Please file an issue.")
578+
end
579+
end
566580
for f in worklist, use in uses(f)
567-
inst = user(use)::Instruction
568-
bb = LLVM.parent(inst)
569-
new_f = LLVM.parent(bb)
570-
in(new_f, worklist) || push!(additions, new_f)
581+
check_user(user(use))
571582
end
572583
for f in additions
573584
push!(worklist, f)
@@ -595,7 +606,39 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
595606
workmap[f] = new_f
596607
end
597608

598-
# clone and rewrite the function bodies
609+
# clone and rewrite the function bodies, replacing uses of the old stateless function
610+
# with the newly created definition that includes the state argument.
611+
#
612+
# most uses are rewritten by LLVM by putting the functions in the value map.
613+
# a separate value materializer is used to recreate constant expressions.
614+
#
615+
# note that this only _replaces_ the uses of these functions, we'll still need to
616+
# _correct_ the uses (i.e. actually add the state argument) afterwards.
617+
function materializer(val)
618+
if val isa ConstantExpr
619+
if opcode(val) == LLVM.API.LLVMBitCast
620+
target = operands(val)[1]
621+
if target isa LLVM.Function && haskey(workmap, target)
622+
# the function is being bitcasted to a different function type.
623+
# we need to mutate that function type to include the state argument,
624+
# or we'd be invoking the original function in an invalid way.
625+
#
626+
# XXX: ptrtoint/inttoptr pairs can also lose the state argument...
627+
# is all this even sound?
628+
typ = llvmtype(val)::LLVM.PointerType
629+
ft = eltype(typ)::LLVM.FunctionType
630+
new_ft = LLVM.FunctionType(return_type(ft), [T_state, parameters(ft)...])
631+
return const_bitcast(workmap[target], LLVM.PointerType(new_ft, addrspace(typ)))
632+
end
633+
elseif opcode(val) == LLVM.API.LLVMPtrToInt
634+
target = operands(val)[1]
635+
if target isa LLVM.Function && haskey(workmap, target)
636+
return const_ptrtoint(workmap[target], llvmtype(val))
637+
end
638+
end
639+
end
640+
return val
641+
end
599642
for (f, new_f) in workmap
600643
# use a value mapper for rewriting function arguments
601644
value_map = Dict{LLVM.Value, LLVM.Value}()
@@ -604,30 +647,54 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
604647
value_map[param] = new_param
605648
end
606649

607-
value_map[f] = new_f
608-
clone_into!(new_f, f; value_map,
650+
# rewrite references to the old function
651+
merge!(value_map, workmap)
652+
653+
clone_into!(new_f, f; value_map, materializer,
609654
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
610655

611-
# we can't remove this function yet, as we might still need to rewrite any called,
612-
# but remove the IR already
656+
# remove the function IR so that we won't have any uses left after this pass.
613657
empty!(f)
614658
end
615659

616-
# update other uses of the old function, modifying call sites to pass the state argument
617-
function rewrite_uses!(f, new_f)
660+
# ensure the old (stateless) functions don't have uses anymore, and remove them
661+
for f in keys(workmap)
662+
for use in uses(f)
663+
val = user(use)
664+
if val isa ConstantExpr
665+
# XXX: shouldn't clone_into! remove unused CEs?
666+
isempty(uses(val)) || error("old function still has uses (via a constant expr)")
667+
LLVM.unsafe_destroy!(val)
668+
else
669+
error("old function still has uses")
670+
end
671+
end
672+
unsafe_delete!(mod, f)
673+
end
674+
675+
# update uses of the new function, modifying call sites to include the kernel state
676+
function rewrite_uses!(f)
618677
# update uses
619678
Builder(ctx) do builder
620679
for use in uses(f)
621680
val = user(use)
622-
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
623-
# NOTE: we unconditionally add the state argument, even if there's no uses,
624-
# assuming we'll perform dead arg elimination during optimization.
681+
if val isa LLVM.CallBase && called_value(val) == f
682+
# NOTE: we don't rewrite calls using Julia's jlcall calling convention,
683+
# as those have a fixed argument list, passing actual arguments
684+
# in an array of objects. that doesn't matter, for now, since
685+
# GPU back-ends don't support such calls anyhow. but if we ever
686+
# want to support kernel state passing on more capable back-ends,
687+
# we'll need to update the argument array instead.
688+
if callconv(val) == 37 || callconv(val) == 38
689+
# TODO: update for LLVM 15 when JuliaLang/julia#45088 is merged.
690+
continue
691+
end
625692

626693
# forward the state argument
627694
position!(builder, val)
628695
state = call!(builder, state_intr, Value[], "state")
629696
new_val = if val isa LLVM.CallInst
630-
call!(builder, new_f, [state, arguments(val)...], operand_bundles(val))
697+
call!(builder, f, [state, arguments(val)...], operand_bundles(val))
631698
else
632699
# TODO: invoke and callbr
633700
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
@@ -637,16 +704,19 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
637704
replace_uses!(val, new_val)
638705
@assert isempty(uses(val))
639706
unsafe_delete!(LLVM.parent(val), val)
707+
elseif val isa LLVM.CallBase
708+
# the function is being passed as an argument, which we'll just permit,
709+
# because we expect to have rewritten the call down the line separately.
710+
elseif val isa ConstantExpr
711+
rewrite_uses!(val)
640712
else
641713
error("Cannot rewrite unknown use of function: $val")
642714
end
643715
end
644716
end
645717
end
646-
for (f, new_f) in workmap
647-
rewrite_uses!(f, new_f)
648-
@assert isempty(uses(f))
649-
unsafe_delete!(mod, f)
718+
for f in values(workmap)
719+
rewrite_uses!(f)
650720
end
651721

652722
return true

src/validation.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -236,8 +236,8 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
236236
if !valid_function_pointer(job, ptr)
237237
# look it up in the Julia JIT cache
238238
frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint,), ptr, 0)
239-
if length(frames) >= 1
240-
@compiler_assert length(frames) == 1 job frames=frames
239+
# XXX: what if multiple frames are returned? rare, but happens
240+
if length(frames) == 1
241241
fn, file, line, linfo, fromC, inlined = last(frames)
242242
push!(errors, (POINTER_FUNCTION, bt, fn))
243243
else

0 commit comments

Comments
 (0)