3030#include " llvm/IR/Metadata.h"
3131#include " llvm/IR/PassManager.h"
3232#include " llvm/IR/PatternMatch.h"
33+ #include " llvm/IR/ProfDataUtils.h"
3334#include " llvm/IR/Type.h"
3435#include " llvm/IR/Use.h"
3536#include " llvm/IR/Value.h"
4748#include " llvm/Transforms/Utils/SSAUpdater.h"
4849#include < algorithm>
4950#include < cassert>
51+ #include < optional>
5052#include < utility>
5153
5254using namespace llvm ;
@@ -85,7 +87,46 @@ using PhiMap = MapVector<PHINode *, BBValueVector>;
8587using BB2BBVecMap = MapVector<BasicBlock *, BBVector>;
8688
8789using BBPhiMap = DenseMap<BasicBlock *, PhiMap>;
88- using BBPredicates = DenseMap<BasicBlock *, Value *>;
90+
91+ using MaybeCondBranchWeights = std::optional<class CondBranchWeights >;
92+
93+ class CondBranchWeights {
94+ uint32_t TrueWeight;
95+ uint32_t FalseWeight;
96+
97+ public:
98+ CondBranchWeights (unsigned T, unsigned F) : TrueWeight(T), FalseWeight(F) {}
99+
100+ static MaybeCondBranchWeights tryParse (const BranchInst &Br) {
101+ assert (Br.isConditional ());
102+
103+ SmallVector<uint32_t , 2 > Weights;
104+ if (!extractBranchWeights (Br, Weights))
105+ return std::nullopt ;
106+
107+ if (Weights.size () != 2 )
108+ return std::nullopt ;
109+
110+ return CondBranchWeights{Weights[0 ], Weights[1 ]};
111+ }
112+
113+ static void setMetadata (BranchInst &Br,
114+ MaybeCondBranchWeights const &Weights) {
115+ assert (Br.isConditional ());
116+ if (!Weights)
117+ return ;
118+ uint32_t Arr[] = {Weights->TrueWeight , Weights->FalseWeight };
119+ setBranchWeights (Br, Arr, false );
120+ }
121+
122+ CondBranchWeights invert () const {
123+ return CondBranchWeights{FalseWeight, TrueWeight};
124+ }
125+ };
126+
127+ using ValueWeightPair = std::pair<Value *, MaybeCondBranchWeights>;
128+
129+ using BBPredicates = DenseMap<BasicBlock *, ValueWeightPair>;
89130using PredMap = DenseMap<BasicBlock *, BBPredicates>;
90131using BB2BBMap = DenseMap<BasicBlock *, BasicBlock *>;
91132
@@ -271,7 +312,7 @@ class StructurizeCFG {
271312
272313 void analyzeLoops (RegionNode *N);
273314
274- Value * buildCondition (BranchInst *Term, unsigned Idx, bool Invert);
315+ ValueWeightPair buildCondition (BranchInst *Term, unsigned Idx, bool Invert);
275316
276317 void gatherPredicates (RegionNode *N);
277318
@@ -449,16 +490,22 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) {
449490}
450491
451492// / Build the condition for one edge
452- Value * StructurizeCFG::buildCondition (BranchInst *Term, unsigned Idx,
453- bool Invert) {
493+ ValueWeightPair StructurizeCFG::buildCondition (BranchInst *Term, unsigned Idx,
494+ bool Invert) {
454495 Value *Cond = Invert ? BoolFalse : BoolTrue;
496+ MaybeCondBranchWeights Weights = std::nullopt ;
497+
455498 if (Term->isConditional ()) {
456499 Cond = Term->getCondition ();
500+ Weights = CondBranchWeights::tryParse (*Term);
457501
458- if (Idx != (unsigned )Invert)
502+ if (Idx != (unsigned )Invert) {
459503 Cond = invertCondition (Cond);
504+ if (Weights)
505+ Weights = Weights->invert ();
506+ }
460507 }
461- return Cond;
508+ return { Cond, Weights} ;
462509}
463510
464511// / Analyze the predecessors of each block and build up predicates
@@ -490,8 +537,8 @@ void StructurizeCFG::gatherPredicates(RegionNode *N) {
490537 if (Visited.count (Other) && !Loops.count (Other) &&
491538 !Pred.count (Other) && !Pred.count (P)) {
492539
493- Pred[Other] = BoolFalse;
494- Pred[P] = BoolTrue;
540+ Pred[Other] = { BoolFalse, std:: nullopt } ;
541+ Pred[P] = { BoolTrue, std:: nullopt } ;
495542 continue ;
496543 }
497544 }
@@ -512,9 +559,9 @@ void StructurizeCFG::gatherPredicates(RegionNode *N) {
512559
513560 BasicBlock *Entry = R->getEntry ();
514561 if (Visited.count (Entry))
515- Pred[Entry] = BoolTrue;
562+ Pred[Entry] = { BoolTrue, std:: nullopt } ;
516563 else
517- LPred[Entry] = BoolFalse;
564+ LPred[Entry] = { BoolFalse, std:: nullopt } ;
518565 }
519566 }
520567}
@@ -578,12 +625,14 @@ void StructurizeCFG::insertConditions(bool Loops) {
578625 Dominator.addBlock (Parent);
579626
580627 Value *ParentValue = nullptr ;
581- for (std::pair<BasicBlock *, Value *> BBAndPred : Preds) {
628+ MaybeCondBranchWeights ParentWeights = std::nullopt ;
629+ for (std::pair<BasicBlock *, ValueWeightPair> BBAndPred : Preds) {
582630 BasicBlock *BB = BBAndPred.first ;
583- Value *Pred = BBAndPred.second ;
631+ Value *Pred = BBAndPred.second . first ;
584632
585633 if (BB == Parent) {
586634 ParentValue = Pred;
635+ ParentWeights = BBAndPred.second .second ;
587636 break ;
588637 }
589638 PhiInserter.AddAvailableValue (BB, Pred);
@@ -592,6 +641,7 @@ void StructurizeCFG::insertConditions(bool Loops) {
592641
593642 if (ParentValue) {
594643 Term->setCondition (ParentValue);
644+ CondBranchWeights::setMetadata (*Term, ParentWeights);
595645 } else {
596646 if (!Dominator.resultIsRememberedBlock ())
597647 PhiInserter.AddAvailableValue (Dominator.result (), Default);
@@ -607,7 +657,7 @@ void StructurizeCFG::simplifyConditions() {
607657 for (auto &I : concat<PredMap::value_type>(Predicates, LoopPreds)) {
608658 auto &Preds = I.second ;
609659 for (auto &J : Preds) {
610- auto &Cond = J.second ;
660+ auto &Cond = J.second . first ;
611661 Instruction *Inverted;
612662 if (match (Cond, m_Not (m_OneUse (m_Instruction (Inverted)))) &&
613663 !Cond->use_empty ()) {
@@ -904,9 +954,10 @@ void StructurizeCFG::setPrevNode(BasicBlock *BB) {
904954// / Does BB dominate all the predicates of Node?
905955bool StructurizeCFG::dominatesPredicates (BasicBlock *BB, RegionNode *Node) {
906956 BBPredicates &Preds = Predicates[Node->getEntry ()];
907- return llvm::all_of (Preds, [&](std::pair<BasicBlock *, Value *> Pred) {
908- return DT->dominates (BB, Pred.first );
909- });
957+ return llvm::all_of (Preds,
958+ [&](std::pair<BasicBlock *, ValueWeightPair> Pred) {
959+ return DT->dominates (BB, Pred.first );
960+ });
910961}
911962
912963// / Can we predict that this node will always be called?
@@ -918,9 +969,9 @@ bool StructurizeCFG::isPredictableTrue(RegionNode *Node) {
918969 if (!PrevNode)
919970 return true ;
920971
921- for (std::pair<BasicBlock*, Value* > Pred : Preds) {
972+ for (std::pair<BasicBlock *, ValueWeightPair > Pred : Preds) {
922973 BasicBlock *BB = Pred.first ;
923- Value *V = Pred.second ;
974+ Value *V = Pred.second . first ;
924975
925976 if (V != BoolTrue)
926977 return false ;
0 commit comments