@@ -263,6 +263,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
263263public:
264264 using BlockT = typename ContextT::BlockT;
265265 using DominatorTreeT = typename ContextT::DominatorTreeT;
266+ using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
266267 using FunctionT = typename ContextT::FunctionT;
267268 using ValueRefT = typename ContextT::ValueRefT;
268269 using InstructionT = typename ContextT::InstructionT;
@@ -296,7 +297,9 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
296297 using DivergencePropagatorT = DivergencePropagator<ContextT>;
297298
298299 GenericSyncDependenceAnalysis (const ContextT &Context,
299- const DominatorTreeT &DT, const CycleInfoT &CI);
300+ const DominatorTreeT &DT,
301+ const PostDominatorTreeT &PDT,
302+ const CycleInfoT &CI);
300303
301304 /// \brief Computes divergent join points and cycle exits caused by branch
302305 /// divergence in \p Term.
@@ -315,6 +318,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
315318 ModifiedPO CyclePO;
316319
317320 const DominatorTreeT &DT;
321+ const PostDominatorTreeT &PDT;
318322 const CycleInfoT &CI;
319323
320324 DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
@@ -336,6 +340,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
336340 using UseT = typename ContextT::UseT;
337341 using InstructionT = typename ContextT::InstructionT;
338342 using DominatorTreeT = typename ContextT::DominatorTreeT;
343+ using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
339344
340345 using CycleInfoT = GenericCycleInfo<ContextT>;
341346 using CycleT = typename CycleInfoT::CycleT;
@@ -348,10 +353,12 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
348353 using TemporalDivergenceTuple =
349354 std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
350355
351- GenericUniformityAnalysisImpl (const DominatorTreeT &DT, const CycleInfoT &CI,
356+ GenericUniformityAnalysisImpl (const DominatorTreeT &DT,
357+ const PostDominatorTreeT &PDT,
358+ const CycleInfoT &CI,
352359 const TargetTransformInfo *TTI)
353360 : Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
354- TTI (TTI), DT(DT), SDA(Context, DT, CI) {}
361+ TTI (TTI), DT(DT), PDT(PDT), SDA(Context, DT, PDT , CI) {}
355362
356363 void initialize ();
357364
@@ -435,6 +442,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
435442
436443private:
437444 const DominatorTreeT &DT;
445+ const PostDominatorTreeT &PDT;
438446
439447 // Recognized cycles with divergent exits.
440448 SmallPtrSet<const CycleT *, 16 > DivergentExitCycles;
@@ -493,6 +501,7 @@ template <typename ContextT> class DivergencePropagator {
493501public:
494502 using BlockT = typename ContextT::BlockT;
495503 using DominatorTreeT = typename ContextT::DominatorTreeT;
504+ using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
496505 using FunctionT = typename ContextT::FunctionT;
497506 using ValueRefT = typename ContextT::ValueRefT;
498507
@@ -507,6 +516,7 @@ template <typename ContextT> class DivergencePropagator {
507516
508517 const ModifiedPO &CyclePOT;
509518 const DominatorTreeT &DT;
519+ const PostDominatorTreeT &PDT;
510520 const CycleInfoT &CI;
511521 const BlockT &DivTermBlock;
512522 const ContextT &Context;
@@ -522,10 +532,11 @@ template <typename ContextT> class DivergencePropagator {
522532 BlockLabelMapT &BlockLabels;
523533
524534 DivergencePropagator (const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
525- const CycleInfoT &CI, const BlockT &DivTermBlock)
526- : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
527- Context (CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
528- BlockLabels(DivDesc->BlockLabels) {}
535+ const PostDominatorTreeT &PDT, const CycleInfoT &CI,
536+ const BlockT &DivTermBlock)
537+ : CyclePOT(CyclePOT), DT(DT), PDT(PDT), CI(CI),
538+ DivTermBlock (DivTermBlock), Context(CI.getSSAContext()),
539+ DivDesc(new DivergenceDescriptorT), BlockLabels(DivDesc->BlockLabels) {}
529540
530541 void printDefs (raw_ostream &Out) {
531542 Out << " Propagator::BlockLabels {\n " ;
@@ -542,6 +553,12 @@ template <typename ContextT> class DivergencePropagator {
542553 Out << " }\n " ;
543554 }
544555
556+ const BlockT *getIPDom (const BlockT *B) {
557+ const auto *Node = PDT.getNode (B);
558+ const auto *IPDomNode = Node->getIDom ();
559+ return IPDomNode->getBlock ();
560+ }
561+
545562 // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
546563 // causes a divergent join.
547564 bool computeJoin (const BlockT &SuccBlock, const BlockT &PushedLabel) {
@@ -610,10 +627,11 @@ template <typename ContextT> class DivergencePropagator {
610627 LLVM_DEBUG (dbgs () << " SDA:computeJoinPoints: "
611628 << Context.print (&DivTermBlock) << " \n " );
612629
613- // Early stopping criterion
614- int FloorIdx = CyclePOT.size () - 1 ;
615- const BlockT *FloorLabel = nullptr ;
616- int DivTermIdx = CyclePOT.getIndex (&DivTermBlock);
630+ // Immediate Post-dominator of DivTermBlock is the last join
631+ // to visit.
632+ const auto *ImmPDom = getIPDom (&DivTermBlock);
633+
634+ LLVM_DEBUG (dbgs () << " Last join: " << Context.print (ImmPDom) << " \n " );
617635
618636 // Bootstrap with branch targets
619637 auto const *DivTermCycle = CI.getCycle (&DivTermBlock);
@@ -626,34 +644,29 @@ template <typename ContextT> class DivergencePropagator {
626644 LLVM_DEBUG (dbgs () << " \t Immediate divergent cycle exit: "
627645 << Context.print (SuccBlock) << " \n " );
628646 }
629- auto SuccIdx = CyclePOT.getIndex (SuccBlock);
630647 visitEdge (*SuccBlock, *SuccBlock);
631- FloorIdx = std::min<int >(FloorIdx, SuccIdx);
632648 }
633649
634650 while (true ) {
635651 auto BlockIdx = FreshLabels.find_last ();
636- if (BlockIdx == -1 || BlockIdx < FloorIdx )
652+ if (BlockIdx == -1 )
637653 break ;
638654
639655 LLVM_DEBUG (dbgs () << " Current labels:\n " ; printDefs (dbgs ()));
640656
641657 FreshLabels.reset (BlockIdx);
642- if (BlockIdx == DivTermIdx) {
643- LLVM_DEBUG (dbgs () << " Skipping DivTermBlock\n " );
658+ const auto *Block = CyclePOT[BlockIdx];
659+ if (Block == ImmPDom) {
660+ LLVM_DEBUG (dbgs () << " Skipping ImmPDom\n " );
644661 continue ;
645662 }
646663
647- const auto *Block = CyclePOT[BlockIdx];
648664 LLVM_DEBUG (dbgs () << " visiting " << Context.print (Block) << " at index "
649665 << BlockIdx << " \n " );
650666
651667 const auto *Label = BlockLabels[Block];
652668 assert (Label);
653669
654- bool CausedJoin = false ;
655- int LoweredFloorIdx = FloorIdx;
656-
657670 // If the current block is the header of a reducible cycle that
658671 // contains the divergent branch, then the label should be
659672 // propagated to the cycle exits. Such a header is the "last
@@ -681,28 +694,11 @@ template <typename ContextT> class DivergencePropagator {
681694 if (const auto *BlockCycle = getReducibleParent (Block)) {
682695 SmallVector<BlockT *, 4 > BlockCycleExits;
683696 BlockCycle->getExitBlocks (BlockCycleExits);
684- for (auto *BlockCycleExit : BlockCycleExits) {
685- CausedJoin |= visitCycleExitEdge (*BlockCycleExit, *Label);
686- LoweredFloorIdx =
687- std::min<int >(LoweredFloorIdx, CyclePOT.getIndex (BlockCycleExit));
688- }
697+ for (auto *BlockCycleExit : BlockCycleExits)
698+ visitCycleExitEdge (*BlockCycleExit, *Label);
689699 } else {
690- for (const auto *SuccBlock : successors (Block)) {
691- CausedJoin |= visitEdge (*SuccBlock, *Label);
692- LoweredFloorIdx =
693- std::min<int >(LoweredFloorIdx, CyclePOT.getIndex (SuccBlock));
694- }
695- }
696-
697- // Floor update
698- if (CausedJoin) {
699- // 1. Different labels pushed to successors
700- FloorIdx = LoweredFloorIdx;
701- } else if (FloorLabel != Label) {
702- // 2. No join caused BUT we pushed a label that is different than the
703- // last pushed label
704- FloorIdx = LoweredFloorIdx;
705- FloorLabel = Label;
700+ for (const auto *SuccBlock : successors (Block))
701+ visitEdge (*SuccBlock, *Label);
706702 }
707703 }
708704
@@ -742,8 +738,9 @@ typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
742738
743739template <typename ContextT>
744740llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
745- const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
746- : CyclePO(Context), DT(DT), CI(CI) {
741+ const ContextT &Context, const DominatorTreeT &DT,
742+ const PostDominatorTreeT &PDT, const CycleInfoT &CI)
743+ : CyclePO(Context), DT(DT), PDT(PDT), CI(CI) {
747744 CyclePO.compute (CI);
748745}
749746
@@ -761,7 +758,7 @@ auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
761758 return *ItCached->second ;
762759
763760 // compute all join points
764- DivergencePropagatorT Propagator (CyclePO, DT, CI, *DivTermBlock);
761+ DivergencePropagatorT Propagator (CyclePO, DT, PDT, CI, *DivTermBlock);
765762 auto DivDesc = Propagator.computeJoinPoints ();
766763
767764 auto printBlockSet = [&](ConstBlockSet &Blocks) {
@@ -1155,9 +1152,9 @@ bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
11551152
11561153template <typename ContextT>
11571154GenericUniformityInfo<ContextT>::GenericUniformityInfo(
1158- const DominatorTreeT &DT, const CycleInfoT &CI ,
1159- const TargetTransformInfo *TTI) {
1160- DA.reset (new ImplT{DT, CI, TTI});
1155+ const DominatorTreeT &DT, const PostDominatorTreeT &PDT ,
1156+ const CycleInfoT &CI, const TargetTransformInfo *TTI) {
1157+ DA.reset (new ImplT{DT, PDT, CI, TTI});
11611158}
11621159
11631160template <typename ContextT>
0 commit comments