@@ -563,11 +563,22 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
563
563
# iteratively discover functions that use the intrinsic or any function calling it
564
564
worklist_length = length (worklist)
565
565
additions = LLVM. Function[]
566
# Queue the function that contains `val` for processing, unless it is already part of
# the worklist. Constant expressions are looked through to whatever uses them.
function check_user(val)
    if val isa ConstantExpr
        # constant expressions don't have a parent function of their own, so inspect
        # each of their users instead.
        foreach(use -> check_user(user(use)), uses(val))
    elseif val isa Instruction
        # walk up: instruction -> basic block -> function
        caller = LLVM.parent(LLVM.parent(val))
        if !(caller in worklist)
            push!(additions, caller)
        end
    else
        error("Don't know how to check uses of $val. Please file an issue.")
    end
end
566
580
for f in worklist, use in uses (f)
567
- inst = user (use):: Instruction
568
- bb = LLVM. parent (inst)
569
- new_f = LLVM. parent (bb)
570
- in (new_f, worklist) || push! (additions, new_f)
581
+ check_user (user (use))
571
582
end
572
583
for f in additions
573
584
push! (worklist, f)
@@ -595,7 +606,39 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
595
606
workmap[f] = new_f
596
607
end
597
608
598
- # clone and rewrite the function bodies
609
+ # clone and rewrite the function bodies, replacing uses of the old stateless function
610
+ # with the newly created definition that includes the state argument.
611
+ #
612
+ # most uses are rewritten by LLVM by putting the functions in the value map.
613
+ # a separate value materializer is used to recreate constant expressions.
614
+ #
615
+ # note that this only _replaces_ the uses of these functions, we'll still need to
616
+ # _correct_ the uses (i.e. actually add the state argument) afterwards.
617
# Value materializer handed to `clone_into!`: recreates constant expressions that wrap a
# function present in `workmap` so that they refer to its replacement instead. Any other
# value is returned unchanged.
function materializer(val)
    val isa ConstantExpr || return val

    op = opcode(val)
    if op == LLVM.API.LLVMBitCast
        callee = operands(val)[1]
        if callee isa LLVM.Function && haskey(workmap, callee)
            # the function is being bitcasted to a different function type.
            # that function type has to be rebuilt with the state argument in front,
            # or we'd be invoking the original function in an invalid way.
            #
            # XXX: ptrtoint/inttoptr pairs can also lose the state argument...
            #      is all this even sound?
            ptr_typ = llvmtype(val)::LLVM.PointerType
            old_ft = eltype(ptr_typ)::LLVM.FunctionType
            stateful_ft = LLVM.FunctionType(return_type(old_ft),
                                            [T_state, parameters(old_ft)...])
            stateful_ptr = LLVM.PointerType(stateful_ft, addrspace(ptr_typ))
            return const_bitcast(workmap[callee], stateful_ptr)
        end
    elseif op == LLVM.API.LLVMPtrToInt
        callee = operands(val)[1]
        if callee isa LLVM.Function && haskey(workmap, callee)
            # same integer type, but taken from the address of the new function
            return const_ptrtoint(workmap[callee], llvmtype(val))
        end
    end

    # not a constant expression we know how to recreate
    return val
end
599
642
for (f, new_f) in workmap
600
643
# use a value mapper for rewriting function arguments
601
644
value_map = Dict {LLVM.Value, LLVM.Value} ()
@@ -604,30 +647,54 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
604
647
value_map[param] = new_param
605
648
end
606
649
607
- value_map[f] = new_f
608
- clone_into! (new_f, f; value_map,
650
+ # rewrite references to the old function
651
+ merge! (value_map, workmap)
652
+
653
+ clone_into! (new_f, f; value_map, materializer,
609
654
changes= LLVM. API. LLVMCloneFunctionChangeTypeGlobalChanges)
610
655
611
- # we can't remove this function yet, as we might still need to rewrite any called,
612
- # but remove the IR already
656
+ # remove the function IR so that we won't have any uses left after this pass.
613
657
empty! (f)
614
658
end
615
659
616
- # update other uses of the old function, modifying call sites to pass the state argument
617
- function rewrite_uses! (f, new_f)
660
+ # ensure the old (stateless) functions don't have uses anymore, and remove them
661
+ for f in keys (workmap)
662
+ for use in uses (f)
663
+ val = user (use)
664
+ if val isa ConstantExpr
665
+ # XXX : shouldn't clone_into! remove unused CEs?
666
+ isempty (uses (val)) || error (" old function still has uses (via a constant expr)" )
667
+ LLVM. unsafe_destroy! (val)
668
+ else
669
+ error (" old function still has uses" )
670
+ end
671
+ end
672
+ unsafe_delete! (mod, f)
673
+ end
674
+
675
+ # update uses of the new function, modifying call sites to include the kernel state
676
+ function rewrite_uses! (f)
618
677
# update uses
619
678
Builder (ctx) do builder
620
679
for use in uses (f)
621
680
val = user (use)
622
- if val isa LLVM. CallInst || val isa LLVM. InvokeInst || val isa LLVM. CallBrInst
623
- # NOTE: we unconditionally add the state argument, even if there's no uses,
624
- # assuming we'll perform dead arg elimination during optimization.
681
+ if val isa LLVM. CallBase && called_value (val) == f
682
+ # NOTE: we don't rewrite calls using Julia's jlcall calling convention,
683
+ # as those have a fixed argument list, passing actual arguments
684
+ # in an array of objects. that doesn't matter, for now, since
685
+ # GPU back-ends don't support such calls anyhow. but if we ever
686
+ # want to support kernel state passing on more capable back-ends,
687
+ # we'll need to update the argument array instead.
688
+ if callconv (val) == 37 || callconv (val) == 38
689
+ # TODO : update for LLVM 15 when JuliaLang/julia#45088 is merged.
690
+ continue
691
+ end
625
692
626
693
# forward the state argument
627
694
position! (builder, val)
628
695
state = call! (builder, state_intr, Value[], " state" )
629
696
new_val = if val isa LLVM. CallInst
630
- call! (builder, new_f , [state, arguments (val)... ], operand_bundles (val))
697
+ call! (builder, f , [state, arguments (val)... ], operand_bundles (val))
631
698
else
632
699
# TODO : invoke and callbr
633
700
error (" Rewrite of $(typeof (val)) -based calls is not implemented: $val " )
@@ -637,16 +704,19 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
637
704
replace_uses! (val, new_val)
638
705
@assert isempty (uses (val))
639
706
unsafe_delete! (LLVM. parent (val), val)
707
+ elseif val isa LLVM. CallBase
708
+ # the function is being passed as an argument, which we'll just permit,
709
+ # because we expect to have rewritten the call down the line separately.
710
+ elseif val isa ConstantExpr
711
+ rewrite_uses! (val)
640
712
else
641
713
error (" Cannot rewrite unknown use of function: $val " )
642
714
end
643
715
end
644
716
end
645
717
end
646
- for (f, new_f) in workmap
647
- rewrite_uses! (f, new_f)
648
- @assert isempty (uses (f))
649
- unsafe_delete! (mod, f)
718
+ for f in values (workmap)
719
+ rewrite_uses! (f)
650
720
end
651
721
652
722
return true
0 commit comments