Skip to content

Commit cc06bc0

Browse files
committed
[DFAJumpThreading] Preserve BFI and BPI during unfolding selects
1 parent ed113e7 commit cc06bc0

File tree

1 file changed

+56
-10
lines changed

1 file changed

+56
-10
lines changed

llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
#include "llvm/ADT/DenseMap.h"
6363
#include "llvm/ADT/Statistic.h"
6464
#include "llvm/Analysis/AssumptionCache.h"
65+
#include "llvm/Analysis/BranchProbabilityInfo.h"
6566
#include "llvm/Analysis/CodeMetrics.h"
6667
#include "llvm/Analysis/DomTreeUpdater.h"
6768
#include "llvm/Analysis/LoopInfo.h"
@@ -141,18 +142,21 @@ class SelectInstToUnfold {
141142
explicit operator bool() const { return SI && SIUse; }
142143
};
143144

144-
void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold,
145+
void unfold(DomTreeUpdater *DTU, LoopInfo *LI, BlockFrequencyInfo *BFI,
146+
BranchProbabilityInfo *BPI, SelectInstToUnfold SIToUnfold,
145147
std::vector<SelectInstToUnfold> *NewSIsToUnfold,
146148
std::vector<BasicBlock *> *NewBBs);
147149

148150
class DFAJumpThreading {
149151
public:
150152
DFAJumpThreading(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
153+
BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI,
151154
TargetTransformInfo *TTI, OptimizationRemarkEmitter *ORE)
152-
: AC(AC), DT(DT), LI(LI), TTI(TTI), ORE(ORE) {}
155+
: AC(AC), DT(DT), LI(LI), BFI(BFI), BPI(BPI), TTI(TTI), ORE(ORE) {}
153156

154157
bool run(Function &F);
155158
bool LoopInfoBroken;
159+
bool BFIBPIBroken;
156160

157161
private:
158162
void
@@ -167,7 +171,7 @@ class DFAJumpThreading {
167171

168172
std::vector<SelectInstToUnfold> NewSIsToUnfold;
169173
std::vector<BasicBlock *> NewBBs;
170-
unfold(&DTU, LI, SIToUnfold, &NewSIsToUnfold, &NewBBs);
174+
unfold(&DTU, LI, BFI, BPI, SIToUnfold, &NewSIsToUnfold, &NewBBs);
171175

172176
// Put newly discovered select instructions into the work list.
173177
llvm::append_range(Stack, NewSIsToUnfold);
@@ -177,6 +181,8 @@ class DFAJumpThreading {
177181
AssumptionCache *AC;
178182
DominatorTree *DT;
179183
LoopInfo *LI;
184+
BlockFrequencyInfo *BFI;
185+
BranchProbabilityInfo *BPI;
180186
TargetTransformInfo *TTI;
181187
OptimizationRemarkEmitter *ORE;
182188
};
@@ -192,17 +198,32 @@ namespace {
192198
/// created basic blocks into \p NewBBs.
193199
///
194200
/// TODO: merge it with CodeGenPrepare::optimizeSelectInst() if possible.
195-
void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold,
201+
void unfold(DomTreeUpdater *DTU, LoopInfo *LI, BlockFrequencyInfo *BFI,
202+
BranchProbabilityInfo *BPI, SelectInstToUnfold SIToUnfold,
196203
std::vector<SelectInstToUnfold> *NewSIsToUnfold,
197204
std::vector<BasicBlock *> *NewBBs) {
198205
SelectInst *SI = SIToUnfold.getInst();
199206
PHINode *SIUse = SIToUnfold.getUse();
200207
assert(SI->hasOneUse());
201208
// The select may come indirectly, instead of from where it is defined.
202209
BasicBlock *StartBlock = SIUse->getIncomingBlock(*SI->use_begin());
203-
BranchInst *StartBlockTerm =
204-
dyn_cast<BranchInst>(StartBlock->getTerminator());
205-
assert(StartBlockTerm);
210+
BranchInst *StartBlockTerm = cast<BranchInst>(StartBlock->getTerminator());
211+
212+
uint64_t TrueWeight = 1;
213+
uint64_t FalseWeight = 1;
214+
// Copy probabilities from 'SI' to the created conditional branch.
215+
SmallVector<BranchProbability, 2> SIProbs;
216+
if (extractBranchWeights(*SI, TrueWeight, FalseWeight) &&
217+
(TrueWeight + FalseWeight) != 0) {
218+
SIProbs.emplace_back(BranchProbability::getBranchProbability(
219+
TrueWeight, TrueWeight + FalseWeight));
220+
SIProbs.emplace_back(BranchProbability::getBranchProbability(
221+
FalseWeight, TrueWeight + FalseWeight));
222+
}
223+
if ((TrueWeight + FalseWeight) == 0)
224+
TrueWeight = FalseWeight = 1;
225+
auto FalseProb = BranchProbability::getBranchProbability(
226+
FalseWeight, TrueWeight + FalseWeight);
206227

207228
if (StartBlockTerm->isUnconditional()) {
208229
BasicBlock *EndBlock = StartBlock->getUniqueSuccessor();
@@ -263,6 +284,13 @@ void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold,
263284
BranchInst::Create(EndBlock, NewBlock, SI->getCondition(), StartBlock);
264285
DTU->applyUpdates({{DominatorTree::Insert, StartBlock, EndBlock},
265286
{DominatorTree::Insert, StartBlock, NewBlock}});
287+
288+
// Update BPI if it exists.
289+
if (BPI && !SIProbs.empty())
290+
BPI->setEdgeProbability(StartBlock, SIProbs);
291+
// Update the block frequency of NewBlock.
292+
if (BFI)
293+
BFI->setBlockFreq(NewBlock, BFI->getBlockFreq(StartBlock) * FalseProb);
266294
} else {
267295
BasicBlock *EndBlock = SIUse->getParent();
268296
BasicBlock *NewBlockT = BasicBlock::Create(
@@ -336,6 +364,17 @@ void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold,
336364
StartBlockTerm->setSuccessor(SuccNum, NewBlockT);
337365
DTU->applyUpdates({{DominatorTree::Delete, StartBlock, EndBlock},
338366
{DominatorTree::Insert, StartBlock, NewBlockT}});
367+
// Update BPI if it exists.
368+
if (BPI && !SIProbs.empty())
369+
BPI->setEdgeProbability(NewBlockT, SIProbs);
370+
// Update the block frequencies of both NewBlockT and NewBlockF.
371+
if (BFI) {
372+
assert(BPI && "BPI should be valid if BFI exists");
373+
auto NewBlockTFreq = BFI->getBlockFreq(StartBlock) *
374+
BPI->getEdgeProbability(StartBlock, SuccNum);
375+
BFI->setBlockFreq(NewBlockT, NewBlockTFreq);
376+
BFI->setBlockFreq(NewBlockF, NewBlockTFreq * FalseProb);
377+
}
339378
}
340379

341380
// Preserve loop info
@@ -994,6 +1033,7 @@ struct TransformDFA {
9941033
SmallPtrSet<BasicBlock *, 16> BlocksToClean;
9951034
BlocksToClean.insert_range(successors(SwitchBlock));
9961035

1036+
// TODO: Preserve BFI/BPI while creating exit paths.
9971037
{
9981038
DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy);
9991039
for (const ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
@@ -1378,7 +1418,7 @@ bool DFAJumpThreading::run(Function &F) {
13781418

13791419
SmallVector<AllSwitchPaths, 2> ThreadableLoops;
13801420
bool MadeChanges = false;
1381-
LoopInfoBroken = false;
1421+
LoopInfoBroken = BFIBPIBroken = false;
13821422

13831423
for (BasicBlock &BB : F) {
13841424
auto *SI = dyn_cast<SwitchInst>(BB.getTerminator());
@@ -1431,7 +1471,7 @@ bool DFAJumpThreading::run(Function &F) {
14311471
for (AllSwitchPaths SwitchPaths : ThreadableLoops) {
14321472
TransformDFA Transform(&SwitchPaths, DT, AC, TTI, ORE, EphValues);
14331473
if (Transform.run())
1434-
MadeChanges = LoopInfoBroken = true;
1474+
MadeChanges = LoopInfoBroken = BFIBPIBroken = true;
14351475
}
14361476

14371477
#ifdef EXPENSIVE_CHECKS
@@ -1450,15 +1490,21 @@ PreservedAnalyses DFAJumpThreadingPass::run(Function &F,
14501490
AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
14511491
DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
14521492
LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
1493+
BlockFrequencyInfo *BFI = AM.getCachedResult<BlockFrequencyAnalysis>(F);
1494+
BranchProbabilityInfo *BPI = AM.getCachedResult<BranchProbabilityAnalysis>(F);
14531495
TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
14541496
OptimizationRemarkEmitter ORE(&F);
1455-
DFAJumpThreading ThreadImpl(&AC, &DT, &LI, &TTI, &ORE);
1497+
DFAJumpThreading ThreadImpl(&AC, &DT, &LI, BFI, BPI, &TTI, &ORE);
14561498
if (!ThreadImpl.run(F))
14571499
return PreservedAnalyses::all();
14581500

14591501
PreservedAnalyses PA;
14601502
PA.preserve<DominatorTreeAnalysis>();
14611503
if (!ThreadImpl.LoopInfoBroken)
14621504
PA.preserve<LoopAnalysis>();
1505+
if (!ThreadImpl.BFIBPIBroken) {
1506+
PA.preserve<BranchProbabilityAnalysis>();
1507+
PA.preserve<BlockFrequencyAnalysis>();
1508+
}
14631509
return PA;
14641510
}

0 commit comments

Comments
 (0)