28
28
#include " llvm/Transforms/IPO/Attributor.h"
29
29
#include " llvm/Transforms/Utils/BasicBlockUtils.h"
30
30
#include " llvm/Transforms/Utils/CallGraphUpdater.h"
31
+ #include " llvm/Transforms/Utils/CodeExtractor.h"
31
32
32
33
using namespace llvm ;
33
34
using namespace omp ;
@@ -317,13 +318,17 @@ struct OMPInformationCache : public InformationCache {
317
318
return NumUses;
318
319
}
319
320
321
+ // Helper function to recollect uses of a runtime function.
322
+ void recollectUsesForFunction (RuntimeFunction RTF) {
323
+ auto &RFI = RFIs[RTF];
324
+ RFI.clearUsesMap ();
325
+ collectUses (RFI, /* CollectStats*/ false );
326
+ }
327
+
320
328
// Helper function to recollect uses of all runtime functions.
321
329
void recollectUses () {
322
- for (int Idx = 0 ; Idx < RFIs.size (); ++Idx) {
323
- auto &RFI = RFIs[static_cast <RuntimeFunction>(Idx)];
324
- RFI.clearUsesMap ();
325
- collectUses (RFI, /* CollectStats*/ false );
326
- }
330
+ for (int Idx = 0 ; Idx < RFIs.size (); ++Idx)
331
+ recollectUsesForFunction (static_cast <RuntimeFunction>(Idx));
327
332
}
328
333
329
334
// / Helper to initialize all runtime function information for those defined
@@ -601,15 +606,11 @@ struct OpenMPOpt {
601
606
if (!RFI.Declaration )
602
607
return false ;
603
608
604
- // Check if there any __kmpc_push_proc_bind calls for explicit affinities.
605
- OMPInformationCache::RuntimeFunctionInfo &ProcBindRFI =
606
- OMPInfoCache.RFIs [OMPRTL___kmpc_push_proc_bind];
607
-
608
- // Defensively abort if explicit affinities are set.
609
- // TODO: Track ICV proc_bind to merge when mergable regions have the same
610
- // affinity.
611
- if (ProcBindRFI.Declaration )
612
- return false ;
609
+ // Unmergable calls that prevent merging a parallel region.
610
+ OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
611
+ OMPInfoCache.RFIs [OMPRTL___kmpc_push_proc_bind],
612
+ OMPInfoCache.RFIs [OMPRTL___kmpc_push_num_threads],
613
+ };
613
614
614
615
bool Changed = false ;
615
616
LoopInfo *LI = nullptr ;
@@ -637,6 +638,90 @@ struct OpenMPOpt {
637
638
638
639
auto FiniCB = [&](InsertPointTy CodeGenIP) {};
639
640
641
+ // / Create a sequential execution region within a merged parallel region,
642
+ // / encapsulated in a master construct with a barrier for synchronization.
643
+ auto CreateSequentialRegion = [&](Function *OuterFn,
644
+ BasicBlock *OuterPredBB,
645
+ Instruction *SeqStartI,
646
+ Instruction *SeqEndI) {
647
+ // Isolate the instructions of the sequential region to a separate
648
+ // block.
649
+ BasicBlock *ParentBB = SeqStartI->getParent ();
650
+ BasicBlock *SeqEndBB =
651
+ SplitBlock (ParentBB, SeqEndI->getNextNode (), DT, LI);
652
+ BasicBlock *SeqAfterBB =
653
+ SplitBlock (SeqEndBB, &*SeqEndBB->getFirstInsertionPt (), DT, LI);
654
+ BasicBlock *SeqStartBB =
655
+ SplitBlock (ParentBB, SeqStartI, DT, LI, nullptr , " seq.par.merged" );
656
+
657
+ assert (ParentBB->getUniqueSuccessor () == SeqStartBB &&
658
+ " Expected a different CFG" );
659
+ const DebugLoc DL = ParentBB->getTerminator ()->getDebugLoc ();
660
+ ParentBB->getTerminator ()->eraseFromParent ();
661
+
662
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
663
+ BasicBlock &ContinuationIP) {
664
+ BasicBlock *CGStartBB = CodeGenIP.getBlock ();
665
+ BasicBlock *CGEndBB =
666
+ SplitBlock (CGStartBB, &*CodeGenIP.getPoint (), DT, LI);
667
+ assert (SeqStartBB != nullptr && " SeqStartBB should not be null" );
668
+ CGStartBB->getTerminator ()->setSuccessor (0 , SeqStartBB);
669
+ assert (SeqEndBB != nullptr && " SeqEndBB should not be null" );
670
+ SeqEndBB->getTerminator ()->setSuccessor (0 , CGEndBB);
671
+ };
672
+ auto FiniCB = [&](InsertPointTy CodeGenIP) {};
673
+
674
+ // Find outputs from the sequential region to outside users and
675
+ // broadcast their values to them.
676
+ for (Instruction &I : *SeqStartBB) {
677
+ SmallPtrSet<Instruction *, 4 > OutsideUsers;
678
+ for (User *Usr : I.users ()) {
679
+ Instruction &UsrI = *cast<Instruction>(Usr);
680
+ // Ignore outputs to LT intrinsics, code extraction for the merged
681
+ // parallel region will fix them.
682
+ if (UsrI.isLifetimeStartOrEnd ())
683
+ continue ;
684
+
685
+ if (UsrI.getParent () != SeqStartBB)
686
+ OutsideUsers.insert (&UsrI);
687
+ }
688
+
689
+ if (OutsideUsers.empty ())
690
+ continue ;
691
+
692
+ // Emit an alloca in the outer region to store the broadcasted
693
+ // value.
694
+ const DataLayout &DL = M.getDataLayout ();
695
+ AllocaInst *AllocaI = new AllocaInst (
696
+ I.getType (), DL.getAllocaAddrSpace (), nullptr ,
697
+ I.getName () + " .seq.output.alloc" , &OuterFn->front ().front ());
698
+
699
+ // Emit a store instruction in the sequential BB to update the
700
+ // value.
701
+ new StoreInst (&I, AllocaI, SeqStartBB->getTerminator ());
702
+
703
+ // Emit a load instruction and replace the use of the output value
704
+ // with it.
705
+ for (Instruction *UsrI : OutsideUsers) {
706
+ LoadInst *LoadI = new LoadInst (I.getType (), AllocaI,
707
+ I.getName () + " .seq.output.load" , UsrI);
708
+ UsrI->replaceUsesOfWith (&I, LoadI);
709
+ }
710
+ }
711
+
712
+ OpenMPIRBuilder::LocationDescription Loc (
713
+ InsertPointTy (ParentBB, ParentBB->end ()), DL);
714
+ InsertPointTy SeqAfterIP =
715
+ OMPInfoCache.OMPBuilder .createMaster (Loc, BodyGenCB, FiniCB);
716
+
717
+ OMPInfoCache.OMPBuilder .createBarrier (SeqAfterIP, OMPD_parallel);
718
+
719
+ BranchInst::Create (SeqAfterBB, SeqAfterIP.getBlock ());
720
+
721
+ LLVM_DEBUG (dbgs () << TAG << " After sequential inlining " << *OuterFn
722
+ << " \n " );
723
+ };
724
+
640
725
// Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
641
726
// contained in BB and only separated by instructions that can be
642
727
// redundantly executed in parallel. The block BB is split before the first
@@ -682,6 +767,21 @@ struct OpenMPOpt {
682
767
const DebugLoc DL = BB->getTerminator ()->getDebugLoc ();
683
768
BB->getTerminator ()->eraseFromParent ();
684
769
770
+ // Create sequential regions for sequential instructions that are
771
+ // in-between mergable parallel regions.
772
+ for (auto *It = MergableCIs.begin (), *End = MergableCIs.end () - 1 ;
773
+ It != End; ++It) {
774
+ Instruction *ForkCI = *It;
775
+ Instruction *NextForkCI = *(It + 1 );
776
+
777
+ // Continue if there are no in-between instructions.
778
+ if (ForkCI->getNextNode () == NextForkCI)
779
+ continue ;
780
+
781
+ CreateSequentialRegion (OriginalFn, BB, ForkCI->getNextNode (),
782
+ NextForkCI->getPrevNode ());
783
+ }
784
+
685
785
OpenMPIRBuilder::LocationDescription Loc (InsertPointTy (BB, BB->end ()),
686
786
DL);
687
787
IRBuilder<>::InsertPoint AllocaIP (
@@ -695,7 +795,7 @@ struct OpenMPOpt {
695
795
BranchInst::Create (AfterBB, AfterIP.getBlock ());
696
796
697
797
// Perform the actual outlining.
698
- OMPInfoCache.OMPBuilder .finalize ();
798
+ OMPInfoCache.OMPBuilder .finalize (/* AllowExtractorSinking */ true );
699
799
700
800
Function *OutlinedFn = MergableCIs.front ()->getCaller ();
701
801
@@ -782,16 +882,75 @@ struct OpenMPOpt {
782
882
BasicBlock *BB = It.getFirst ();
783
883
SmallVector<CallInst *, 4 > MergableCIs;
784
884
885
+ // / Returns true if the instruction is mergable, false otherwise.
886
+ // / A terminator instruction is unmergable by definition since merging
887
+ // / works within a BB. Instructions before the mergable region are
888
+ // / mergable if they are not calls to OpenMP runtime functions that may
889
+ // / set different execution parameters for subsequent parallel regions.
890
+ // / Instructions in-between parallel regions are mergable if they are not
891
+ // / calls to any non-intrinsic function since that may call a non-mergable
892
+ // / OpenMP runtime function.
893
+ auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
894
+ // We do not merge across BBs, hence return false (unmergable) if the
895
+ // instruction is a terminator.
896
+ if (I.isTerminator ())
897
+ return false ;
898
+
899
+ if (!isa<CallInst>(&I))
900
+ return true ;
901
+
902
+ CallInst *CI = cast<CallInst>(&I);
903
+ if (IsBeforeMergableRegion) {
904
+ Function *CalledFunction = CI->getCalledFunction ();
905
+ if (!CalledFunction)
906
+ return false ;
907
+ // Return false (unmergable) if the call before the parallel
908
+ // region calls an explicit affinity (proc_bind) or number of
909
+ // threads (num_threads) compiler-generated function. Those settings
910
+ // may be incompatible with following parallel regions.
911
+ // TODO: ICV tracking to detect compatibility.
912
+ for (const auto &RFI : UnmergableCallsInfo) {
913
+ if (CalledFunction == RFI.Declaration )
914
+ return false ;
915
+ }
916
+ } else {
917
+ // Return false (unmergable) if there is a call instruction
918
+ // in-between parallel regions when it is not an intrinsic. It
919
+ // may call an unmergable OpenMP runtime function in its callpath.
920
+ // TODO: Keep track of possible OpenMP calls in the callpath.
921
+ if (!isa<IntrinsicInst>(CI))
922
+ return false ;
923
+ }
924
+
925
+ return true ;
926
+ };
785
927
// Find maximal number of parallel region CIs that are safe to merge.
786
- for (Instruction &I : *BB) {
928
+ for (auto It = BB->begin (), End = BB->end (); It != End;) {
929
+ Instruction &I = *It;
930
+ ++It;
931
+
787
932
if (CIs.count (&I)) {
788
933
MergableCIs.push_back (cast<CallInst>(&I));
789
934
continue ;
790
935
}
791
936
792
- if (isSafeToSpeculativelyExecute (&I, &I, DT))
937
+ // Continue expanding if the instruction is mergable.
938
+ if (IsMergable (I, MergableCIs.empty ()))
793
939
continue ;
794
940
941
+ // Forward the instruction iterator to skip the next parallel region
942
+ // since there is an unmergable instruction which can affect it.
943
+ for (; It != End; ++It) {
944
+ Instruction &SkipI = *It;
945
+ if (CIs.count (&SkipI)) {
946
+ LLVM_DEBUG (dbgs () << TAG << " Skip parallel region " << SkipI
947
+ << " due to " << I << " \n " );
948
+ ++It;
949
+ break ;
950
+ }
951
+ }
952
+
953
+ // Store mergable regions found.
795
954
if (MergableCIs.size () > 1 ) {
796
955
MergableCIsVector.push_back (MergableCIs);
797
956
LLVM_DEBUG (dbgs () << TAG << " Found " << MergableCIs.size ()
@@ -812,15 +971,12 @@ struct OpenMPOpt {
812
971
}
813
972
814
973
if (Changed) {
815
- // Update RFI info to set it up for later passes.
816
- RFI.clearUsesMap ();
817
- OMPInfoCache.collectUses (RFI, /* CollectStats */ false );
818
-
819
- // Collect uses for the emitted barrier call.
820
- OMPInformationCache::RuntimeFunctionInfo &BarrierRFI =
821
- OMPInfoCache.RFIs [OMPRTL___kmpc_barrier];
822
- BarrierRFI.clearUsesMap ();
823
- OMPInfoCache.collectUses (BarrierRFI, /* CollectStats */ false );
974
+ // / Re-collect uses for fork calls, emitted barrier calls, and
975
+ // / any emitted master/end_master calls.
976
+ OMPInfoCache.recollectUsesForFunction (OMPRTL___kmpc_fork_call);
977
+ OMPInfoCache.recollectUsesForFunction (OMPRTL___kmpc_barrier);
978
+ OMPInfoCache.recollectUsesForFunction (OMPRTL___kmpc_master);
979
+ OMPInfoCache.recollectUsesForFunction (OMPRTL___kmpc_end_master);
824
980
}
825
981
826
982
return Changed;
0 commit comments