@@ -605,7 +605,39 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
605
605
workmap[f] = new_f
606
606
end
607
607
608
- # clone and rewrite the function bodies
608
+ # clone and rewrite the function bodies, replacing uses of the old stateless function
609
+ # with the newly created definition that includes the state argument.
610
+ #
611
+ # most uses are rewritten by LLVM by putting the functions in the value map.
612
+ # a separate value materializer is used to recreate constant expressions.
613
+ #
614
+ # note that this only _replaces_ the uses of these functions, we'll still need to
615
+ # _correct_ the uses (i.e. actually add the state argument) afterwards.
616
+ function materializer (val)
617
+ if val isa ConstantExpr
618
+ if opcode (val) == LLVM. API. LLVMBitCast
619
+ target = operands (val)[1 ]
620
+ if target isa LLVM. Function && haskey (workmap, target)
621
+ # the function is being bitcasted to a different function type.
622
+ # we need to mutate that function type to include the state argument,
623
+ # or we'd be invoking the original function in an invalid way.
624
+ #
625
+ # XXX : ptrtoint/inttoptr pairs can also lose the state argument...
626
+ # is all this even sound?
627
+ typ = llvmtype (val):: LLVM.PointerType
628
+ ft = eltype (typ):: LLVM.FunctionType
629
+ new_ft = LLVM. FunctionType (return_type (ft), [T_state, parameters (ft)... ])
630
+ return const_bitcast (workmap[target], LLVM. PointerType (new_ft, addrspace (typ)))
631
+ end
632
+ elseif opcode (val) == LLVM. API. LLVMPtrToInt
633
+ target = operands (val)[1 ]
634
+ if target isa LLVM. Function && haskey (workmap, target)
635
+ return const_ptrtoint (workmap[target], llvmtype (val))
636
+ end
637
+ end
638
+ end
639
+ return val
640
+ end
609
641
for (f, new_f) in workmap
610
642
# use a value mapper for rewriting function arguments
611
643
value_map = Dict {LLVM.Value, LLVM.Value} ()
@@ -614,33 +646,43 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
614
646
value_map[param] = new_param
615
647
end
616
648
617
- value_map[f] = new_f
618
- # XXX : do we want this? we're adding a new arg, after all
619
- clone_into! (new_f, f; value_map,
649
+ # rewrite references to the old function
650
+ merge! (value_map, workmap)
651
+
652
+ clone_into! (new_f, f; value_map, materializer,
620
653
changes= LLVM. API. LLVMCloneFunctionChangeTypeGlobalChanges)
621
654
622
- # we can't remove this function yet, as we might still need to rewrite any called,
623
- # but remove the IR already
655
+ # remove the function IR so that we won't have any uses left after this pass.
624
656
empty! (f)
625
657
end
626
658
627
- # update other uses of the old function, modifying call sites to pass the state argument
628
- # TODO : why isn't this covered by the value mapper above? because we need to add an arg!
629
- # XXX : do this with a value mapper, and a materialize (?) to rewrite calls.
630
- function rewrite_uses! (f, new_f)
659
+ # ensure the old (stateless) functions don't have uses anymore, and remove them
660
+ for f in keys (workmap)
661
+ for use in uses (f)
662
+ val = user (use)
663
+ if val isa ConstantExpr
664
+ # XXX : shouldn't clone_into! remove unused CEs?
665
+ isempty (uses (val)) || error (" old function still has uses (via a constant expr)" )
666
+ LLVM. unsafe_destroy! (val)
667
+ else
668
+ error (" old function still has uses" )
669
+ end
670
+ end
671
+ unsafe_delete! (mod, f)
672
+ end
673
+
674
+ # update uses of the new function, modifying call sites to include the kernel state
675
+ function rewrite_uses! (f)
631
676
# update uses
632
677
Builder (ctx) do builder
633
678
for use in uses (f)
634
679
val = user (use)
635
- if val isa LLVM. CallInst || val isa LLVM. InvokeInst || val isa LLVM. CallBrInst
636
- # NOTE: we unconditionally add the state argument, even if there's no uses,
637
- # assuming we'll perform dead arg elimination during optimization.
638
-
680
+ if val isa LLVM. CallBase && called_value (val) == f
639
681
# forward the state argument
640
682
position! (builder, val)
641
683
state = call! (builder, state_intr, Value[], " state" )
642
684
new_val = if val isa LLVM. CallInst
643
- call! (builder, new_f , [state, arguments (val)... ], operand_bundles (val))
685
+ call! (builder, f , [state, arguments (val)... ], operand_bundles (val))
644
686
else
645
687
# TODO : invoke and callbr
646
688
error (" Rewrite of $(typeof (val)) -based calls is not implemented: $val " )
@@ -650,26 +692,19 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
650
692
replace_uses! (val, new_val)
651
693
@assert isempty (uses (val))
652
694
unsafe_delete! (LLVM. parent (val), val)
653
- elseif val isa ConstantExpr && opcode (val) == LLVM. API. LLVMPtrToInt
654
- # XXX : are these safe? we're not adding an arg
655
- # XXX : in addition, won't this RAUW assert in debug mode?
656
- replace_uses! (val, const_ptrtoint (new_f, llvmtype (val)))
657
- LLVM. unsafe_destroy! (val)
658
- elseif val isa ConstantExpr && opcode (val) == LLVM. API. LLVMBitCast
659
- # XXX : are these safe? we're not adding an arg
660
- # XXX : in addition, won't this RAUW assert in debug mode?
661
- replace_uses! (val, const_bitcast (new_f, llvmtype (val)))
662
- LLVM. unsafe_destroy! (val)
695
+ elseif val isa LLVM. CallBase
696
+ # the function is being passed as an argument, which we'll just permit,
697
+ # because we expect to have rewritten the call down the line separately.
698
+ elseif val isa ConstantExpr
699
+ rewrite_uses! (val)
663
700
else
664
701
error (" Cannot rewrite unknown use of function: $val " )
665
702
end
666
703
end
667
704
end
668
705
end
669
- for (f, new_f) in workmap
670
- rewrite_uses! (f, new_f)
671
- @assert isempty (uses (f))
672
- unsafe_delete! (mod, f)
706
+ for f in values (workmap)
707
+ rewrite_uses! (f)
673
708
end
674
709
675
710
return true
0 commit comments