Skip to content

Commit de54a16

Browse files
committed
Restrict to functions with smaller number of conditional branches
1 parent 3a5592a commit de54a16

File tree

1 file changed

+40
-23
lines changed

1 file changed

+40
-23
lines changed

llvm/lib/Transforms/Scalar/ConstraintElimination.cpp

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,12 @@ struct State {
186186
LoopInfo &LI;
187187
ScalarEvolution &SE;
188188
SmallVector<FactOrCheck, 64> WorkList;
189+
bool AddInductionInfoIntoHeader = false;
189190

190-
State(DominatorTree &DT, LoopInfo &LI, ScalarEvolution &SE)
191-
: DT(DT), LI(LI), SE(SE) {}
191+
State(DominatorTree &DT, LoopInfo &LI, ScalarEvolution &SE,
192+
bool AddInductionInfoIntoHeader = false)
193+
: DT(DT), LI(LI), SE(SE),
194+
AddInductionInfoIntoHeader(AddInductionInfoIntoHeader) {}
192195

193196
/// Process block \p BB and add known facts to work-list.
194197
void addInfoFor(BasicBlock &BB);
@@ -198,7 +201,7 @@ struct State {
198201
void addInfoForInductions(BasicBlock &BB);
199202

200203
void addConditionFactsIntoLoopHeader(BasicBlock &BB);
201-
204+
202205
/// Returns true if we can add a known condition from BB to its successor
203206
/// block Succ.
204207
bool canAddSuccessor(BasicBlock &BB, BasicBlock *Succ) const {
@@ -910,8 +913,9 @@ void State::addConditionFactsIntoLoopHeader(BasicBlock &BB) {
910913
if (!L || L->getHeader() != &BB)
911914
return;
912915
DomTreeNode *DTN = DT.getNode(&BB);
913-
for(PHINode &PN :L->getHeader()->phis()){
914-
if(PN.getNumIncomingValues() != 2 || PN.getParent() != &BB || !SE.isSCEVable(PN.getType()))
916+
for (PHINode &PN : L->getHeader()->phis()) {
917+
if (PN.getNumIncomingValues() != 2 || PN.getParent() != &BB ||
918+
!SE.isSCEVable(PN.getType()))
915919
continue;
916920
auto *AR = dyn_cast_or_null<SCEVAddRecExpr>(SE.getSCEV(&PN));
917921
BasicBlock *LoopPred = L->getLoopPredecessor();
@@ -927,35 +931,36 @@ void State::addConditionFactsIntoLoopHeader(BasicBlock &BB) {
927931
}
928932
auto IncUnsigned = SE.getMonotonicPredicateType(AR, CmpInst::ICMP_UGT);
929933
auto IncSigned = SE.getMonotonicPredicateType(AR, CmpInst::ICMP_SGT);
930-
931-
// Monotonically Increasing
934+
935+
// Monotonically Increasing
932936
bool MonotonicallyIncreasingUnsigned =
933-
IncUnsigned && *IncUnsigned == ScalarEvolution::MonotonicallyIncreasing;
937+
IncUnsigned && *IncUnsigned == ScalarEvolution::MonotonicallyIncreasing;
934938
bool MonotonicallyIncreasingSigned =
935939
IncSigned && *IncSigned == ScalarEvolution::MonotonicallyIncreasing;
936940
if (MonotonicallyIncreasingUnsigned)
937-
WorkList.push_back(
938-
FactOrCheck::getConditionFact(DTN, CmpInst::ICMP_UGE, &PN, StartValue));
941+
WorkList.push_back(FactOrCheck::getConditionFact(DTN, CmpInst::ICMP_UGE,
942+
&PN, StartValue));
939943
if (MonotonicallyIncreasingSigned)
940-
WorkList.push_back(
941-
FactOrCheck::getConditionFact(DTN, CmpInst::ICMP_SGE, &PN, StartValue));
944+
WorkList.push_back(FactOrCheck::getConditionFact(DTN, CmpInst::ICMP_SGE,
945+
&PN, StartValue));
942946

943947
// Monotonically Decreasing
944948
bool MonotonicallyDecreasingUnsigned =
945-
IncUnsigned && *IncUnsigned == ScalarEvolution::MonotonicallyDecreasing;
949+
IncUnsigned && *IncUnsigned == ScalarEvolution::MonotonicallyDecreasing;
946950
bool MonotonicallyDecreasingSigned =
947951
IncSigned && *IncSigned == ScalarEvolution::MonotonicallyDecreasing;
948-
if(MonotonicallyDecreasingUnsigned)
949-
WorkList.push_back(
950-
FactOrCheck::getConditionFact(DTN, CmpInst::ICMP_ULE, &PN, StartValue));
951-
if(MonotonicallyDecreasingSigned)
952-
WorkList.push_back(
953-
FactOrCheck::getConditionFact(DTN, CmpInst::ICMP_SLE, &PN, StartValue));
954-
}
952+
if (MonotonicallyDecreasingUnsigned)
953+
WorkList.push_back(FactOrCheck::getConditionFact(DTN, CmpInst::ICMP_ULE,
954+
&PN, StartValue));
955+
if (MonotonicallyDecreasingSigned)
956+
WorkList.push_back(FactOrCheck::getConditionFact(DTN, CmpInst::ICMP_SLE,
957+
&PN, StartValue));
958+
}
955959
}
956960

957961
void State::addInfoForInductions(BasicBlock &BB) {
958-
addConditionFactsIntoLoopHeader(BB);
962+
if (AddInductionInfoIntoHeader)
963+
addConditionFactsIntoLoopHeader(BB);
959964
auto *L = LI.getLoopFor(&BB);
960965
if (!L || L->getHeader() != &BB)
961966
return;
@@ -1413,7 +1418,7 @@ static std::optional<bool> checkCondition(CmpInst::Predicate Pred, Value *A,
14131418
LLVM_DEBUG(dbgs() << "Checking " << *CheckInst << "\n");
14141419

14151420
auto R = Info.getConstraintForSolving(Pred, A, B);
1416-
if (R.empty() || !R.isValid(Info)){
1421+
if (R.empty() || !R.isValid(Info)) {
14171422
LLVM_DEBUG(dbgs() << " failed to decompose condition\n");
14181423
return std::nullopt;
14191424
}
@@ -1726,6 +1731,16 @@ tryToSimplifyOverflowMath(IntrinsicInst *II, ConstraintInfo &Info,
17261731
return Changed;
17271732
}
17281733

1734+
static unsigned int getNumConditionalBranches(Function &F) {
1735+
unsigned int NumCondBranches = 0;
1736+
for (BasicBlock &BB : F) {
1737+
BranchInst *BranchInstr = dyn_cast_or_null<BranchInst>(BB.getTerminator());
1738+
if (BranchInstr && BranchInstr->isConditional())
1739+
NumCondBranches++;
1740+
}
1741+
return NumCondBranches;
1742+
}
1743+
17291744
static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
17301745
ScalarEvolution &SE,
17311746
OptimizationRemarkEmitter &ORE) {
@@ -1735,7 +1750,9 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
17351750
for (Value &Arg : F.args())
17361751
FunctionArgs.push_back(&Arg);
17371752
ConstraintInfo Info(F.getDataLayout(), FunctionArgs);
1738-
State S(DT, LI, SE);
1753+
unsigned int NumCondBranches = getNumConditionalBranches(F);
1754+
State S(DT, LI, SE,
1755+
/* AddInductionInfoIntoHeader= */ NumCondBranches < MaxRows / 5);
17391756
std::unique_ptr<Module> ReproducerModule(
17401757
DumpReproducers ? new Module(F.getName(), F.getContext()) : nullptr);
17411758

0 commit comments

Comments
 (0)