Skip to content

Commit f08aeb3

Browse files
committed
Properly rewrite values and constant expressions.
Doing more with the CloneFunction APIs, and making sure we rewrite uses of constant expressions.
1 parent f8a6c4e commit f08aeb3

File tree

1 file changed

+64
-29
lines changed

1 file changed

+64
-29
lines changed

src/irgen.jl

Lines changed: 64 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,39 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
605605
workmap[f] = new_f
606606
end
607607

608-
# clone and rewrite the function bodies
608+
# clone and rewrite the function bodies, replacing uses of the old stateless function
609+
# with the newly created definition that includes the state argument.
610+
#
611+
# most uses are rewritten by LLVM by putting the functions in the value map.
612+
# a separate value materializer is used to recreate constant expressions.
613+
#
614+
# note that this only _replaces_ the uses of these functions, we'll still need to
615+
# _correct_ the uses (i.e. actually add the state argument) afterwards.
616+
function materializer(val)
617+
if val isa ConstantExpr
618+
if opcode(val) == LLVM.API.LLVMBitCast
619+
target = operands(val)[1]
620+
if target isa LLVM.Function && haskey(workmap, target)
621+
# the function is being bitcasted to a different function type.
622+
# we need to mutate that function type to include the state argument,
623+
# or we'd be invoking the original function in an invalid way.
624+
#
625+
# XXX: ptrtoint/inttoptr pairs can also lose the state argument...
626+
# is all this even sound?
627+
typ = llvmtype(val)::LLVM.PointerType
628+
ft = eltype(typ)::LLVM.FunctionType
629+
new_ft = LLVM.FunctionType(return_type(ft), [T_state, parameters(ft)...])
630+
return const_bitcast(workmap[target], LLVM.PointerType(new_ft, addrspace(typ)))
631+
end
632+
elseif opcode(val) == LLVM.API.LLVMPtrToInt
633+
target = operands(val)[1]
634+
if target isa LLVM.Function && haskey(workmap, target)
635+
return const_ptrtoint(workmap[target], llvmtype(val))
636+
end
637+
end
638+
end
639+
return val
640+
end
609641
for (f, new_f) in workmap
610642
# use a value mapper for rewriting function arguments
611643
value_map = Dict{LLVM.Value, LLVM.Value}()
@@ -614,33 +646,43 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
614646
value_map[param] = new_param
615647
end
616648

617-
value_map[f] = new_f
618-
# XXX: do we want this? we're adding a new arg, after all
619-
clone_into!(new_f, f; value_map,
649+
# rewrite references to the old function
650+
merge!(value_map, workmap)
651+
652+
clone_into!(new_f, f; value_map, materializer,
620653
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
621654

622-
# we can't remove this function yet, as we might still need to rewrite any called,
623-
# but remove the IR already
655+
# remove the function IR so that we won't have any uses left after this pass.
624656
empty!(f)
625657
end
626658

627-
# update other uses of the old function, modifying call sites to pass the state argument
628-
# TODO: why isn't this covered by the value mapper above? because we need to add an arg!
629-
# XXX: do this with a value mapper, and a materialize (?) to rewrite calls.
630-
function rewrite_uses!(f, new_f)
659+
# ensure the old (stateless) functions don't have uses anymore, and remove them
660+
for f in keys(workmap)
661+
for use in uses(f)
662+
val = user(use)
663+
if val isa ConstantExpr
664+
# XXX: shouldn't clone_into! remove unused CEs?
665+
isempty(uses(val)) || error("old function still has uses (via a constant expr)")
666+
LLVM.unsafe_destroy!(val)
667+
else
668+
error("old function still has uses")
669+
end
670+
end
671+
unsafe_delete!(mod, f)
672+
end
673+
674+
# update uses of the new function, modifying call sites to include the kernel state
675+
function rewrite_uses!(f)
631676
# update uses
632677
Builder(ctx) do builder
633678
for use in uses(f)
634679
val = user(use)
635-
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
636-
# NOTE: we unconditionally add the state argument, even if there's no uses,
637-
# assuming we'll perform dead arg elimination during optimization.
638-
680+
if val isa LLVM.CallBase && called_value(val) == f
639681
# forward the state argument
640682
position!(builder, val)
641683
state = call!(builder, state_intr, Value[], "state")
642684
new_val = if val isa LLVM.CallInst
643-
call!(builder, new_f, [state, arguments(val)...], operand_bundles(val))
685+
call!(builder, f, [state, arguments(val)...], operand_bundles(val))
644686
else
645687
# TODO: invoke and callbr
646688
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
@@ -650,26 +692,19 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
650692
replace_uses!(val, new_val)
651693
@assert isempty(uses(val))
652694
unsafe_delete!(LLVM.parent(val), val)
653-
elseif val isa ConstantExpr && opcode(val) == LLVM.API.LLVMPtrToInt
654-
# XXX: are these safe? we're not adding an arg
655-
# XXX: in addition, won't this RAUW assert in debug mode?
656-
replace_uses!(val, const_ptrtoint(new_f, llvmtype(val)))
657-
LLVM.unsafe_destroy!(val)
658-
elseif val isa ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
659-
# XXX: are these safe? we're not adding an arg
660-
# XXX: in addition, won't this RAUW assert in debug mode?
661-
replace_uses!(val, const_bitcast(new_f, llvmtype(val)))
662-
LLVM.unsafe_destroy!(val)
695+
elseif val isa LLVM.CallBase
696+
# the function is being passed as an argument, which we'll just permit,
697+
# because we expect to have rewritten the call down the line separately.
698+
elseif val isa ConstantExpr
699+
rewrite_uses!(val)
663700
else
664701
error("Cannot rewrite unknown use of function: $val")
665702
end
666703
end
667704
end
668705
end
669-
for (f, new_f) in workmap
670-
rewrite_uses!(f, new_f)
671-
@assert isempty(uses(f))
672-
unsafe_delete!(mod, f)
706+
for f in values(workmap)
707+
rewrite_uses!(f)
673708
end
674709

675710
return true

0 commit comments

Comments
 (0)