Skip to content

Commit 3feaafd

Browse files
committed
[PredicateInfo] Handle switch comprehensively
Now we can handle default-dst and multi-cases dest of switch for PredicateInfo, via using a constant range to model such scenario.
1 parent 7bf89cc commit 3feaafd

File tree

3 files changed

+175
-53
lines changed

3 files changed

+175
-53
lines changed

llvm/include/llvm/Transforms/Utils/PredicateInfo.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ class PredicateBase {
102102

103103
/// Fetch condition in the form of PredicateConstraint, if possible.
104104
LLVM_ABI std::optional<PredicateConstraint> getConstraint() const;
105+
/// Fetch condition in the form of a ConstantRange, if possible.
106+
LLVM_ABI std::optional<ConstantRange> getRangeConstraint() const;
105107

106108
protected:
107109
PredicateBase(PredicateType PT, Value *Op, Value *Condition)
@@ -157,18 +159,22 @@ class PredicateBranch : public PredicateWithEdge {
157159

158160
class PredicateSwitch : public PredicateWithEdge {
159161
public:
160-
Value *CaseValue;
161-
// This is the switch instruction.
162-
SwitchInst *Switch;
162+
using CaseValuesVec = SmallVector<ConstantInt *, 2>;
163+
CaseValuesVec CaseValues;
164+
bool IsDefault;
163165
PredicateSwitch(Value *Op, BasicBlock *SwitchBB, BasicBlock *TargetBB,
164-
Value *CaseValue, SwitchInst *SI)
166+
ArrayRef<ConstantInt *> CaseValues, SwitchInst *SI,
167+
bool IsDefault)
165168
: PredicateWithEdge(PT_Switch, Op, SwitchBB, TargetBB,
166169
SI->getCondition()),
167-
CaseValue(CaseValue), Switch(SI) {}
170+
CaseValues(CaseValues), IsDefault(IsDefault) {}
171+
PredicateSwitch(Value *Op, BasicBlock *SwitchBB, BasicBlock *TargetBB,
172+
CaseValuesVec &&CaseValues, SwitchInst *SI, bool IsDefault)
173+
: PredicateWithEdge(PT_Switch, Op, SwitchBB, TargetBB,
174+
SI->getCondition()),
175+
CaseValues(CaseValues), IsDefault(IsDefault) {}
168176
PredicateSwitch() = delete;
169-
static bool classof(const PredicateBase *PB) {
170-
return PB->Type == PT_Switch;
171-
}
177+
static bool classof(const PredicateBase *PB) { return PB->Type == PT_Switch; }
172178
};
173179

174180
/// Encapsulates PredicateInfo, including all data associated with memory

llvm/lib/Transforms/Utils/PredicateInfo.cpp

Lines changed: 107 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
#include "llvm/ADT/DenseMap.h"
1515
#include "llvm/ADT/STLExtras.h"
1616
#include "llvm/ADT/SmallPtrSet.h"
17+
#include "llvm/ADT/StringExtras.h"
1718
#include "llvm/Analysis/AssumptionCache.h"
1819
#include "llvm/IR/AssemblyAnnotationWriter.h"
20+
#include "llvm/IR/Constants.h"
1921
#include "llvm/IR/Dominators.h"
2022
#include "llvm/IR/IRBuilder.h"
2123
#include "llvm/IR/InstIterator.h"
@@ -442,6 +444,7 @@ void PredicateInfoBuilder::processBranch(
442444
}
443445
}
444446
}
447+
445448
// Process a block terminating switch, and place relevant operations to be
446449
// renamed into OpsToRename.
447450
void PredicateInfoBuilder::processSwitch(
@@ -450,21 +453,41 @@ void PredicateInfoBuilder::processSwitch(
450453
Value *Op = SI->getCondition();
451454
if ((!isa<Instruction>(Op) && !isa<Argument>(Op)) || Op->hasOneUse())
452455
return;
456+
using CaseValuesVec = PredicateSwitch::CaseValuesVec;
457+
458+
BasicBlock *DefaultDest = SI->getDefaultDest();
459+
// Remember all cases for PT_Switch related to the default dest.
460+
CaseValuesVec AllCases;
461+
AllCases.reserve(SI->getNumCases());
453462

454-
// Remember how many outgoing edges there are to every successor.
455-
SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges;
456-
for (BasicBlock *TargetBlock : successors(BranchBB))
457-
++SwitchEdges[TargetBlock];
463+
// For each successor, remember all its related case values.
464+
SmallDenseMap<BasicBlock *, CaseValuesVec, 16> SwitchEdges;
458465

459-
// Now propagate info for each case value
460466
for (auto C : SI->cases()) {
461467
BasicBlock *TargetBlock = C.getCaseSuccessor();
462-
if (SwitchEdges.lookup(TargetBlock) == 1) {
463-
PredicateSwitch *PS = new (Allocator) PredicateSwitch(
464-
Op, SI->getParent(), TargetBlock, C.getCaseValue(), SI);
465-
addInfoFor(OpsToRename, Op, PS);
466-
}
468+
/// TODO: Replace this if with an assertion if we can guarantee that
469+
/// this function must be called after SimplifyCFG, as a canonical switch
470+
/// should not have case dest being the default dest.
471+
if (TargetBlock == DefaultDest)
472+
continue;
473+
// Only collect real case values
474+
ConstantInt *CaseValue = C.getCaseValue();
475+
AllCases.push_back(CaseValue);
476+
SwitchEdges[TargetBlock].push_back(CaseValue);
467477
}
478+
479+
// Now propagate info for each case successor
480+
for (auto *CaseSucc : SwitchEdges.keys()) {
481+
auto &CaseValues = SwitchEdges.at(CaseSucc);
482+
PredicateSwitch *PS = new (Allocator) PredicateSwitch(
483+
Op, SI->getParent(), CaseSucc, std::move(CaseValues), SI, false);
484+
addInfoFor(OpsToRename, Op, PS);
485+
}
486+
487+
// Finally, propagate info for the default case
488+
PredicateSwitch *PS = new (Allocator) PredicateSwitch(
489+
Op, SI->getParent(), DefaultDest, std::move(AllCases), SI, true);
490+
addInfoFor(OpsToRename, Op, PS);
468491
}
469492

470493
// Build predicate info for our function
@@ -500,8 +523,8 @@ void PredicateInfoBuilder::buildPredicateInfo() {
500523
// Given the renaming stack, make all the operands currently on the stack real
501524
// by inserting them into the IR. Return the last operation's value.
502525
Value *PredicateInfoBuilder::materializeStack(unsigned int &Counter,
503-
ValueDFSStack &RenameStack,
504-
Value *OrigOp) {
526+
ValueDFSStack &RenameStack,
527+
Value *OrigOp) {
505528
// Find the first thing we have to materialize
506529
auto RevIter = RenameStack.rbegin();
507530
for (; RevIter != RenameStack.rend(); ++RevIter)
@@ -601,7 +624,8 @@ void PredicateInfoBuilder::renameUses(SmallVectorImpl<Value *> &OpsToRename) {
601624
// block, and handle it specially. We know that it goes last, and only
602625
// dominate phi uses.
603626
auto BlockEdge = getBlockEdge(PossibleCopy);
604-
if (!BlockEdge.second->getSinglePredecessor()) {
627+
// We use unique predecessor to identify the mult-cases dest in switch
628+
if (!BlockEdge.second->getUniquePredecessor()) {
605629
VD.LocalNum = LN_Last;
606630
auto *DomNode = DT.getNode(BlockEdge.first);
607631
if (DomNode) {
@@ -759,8 +783,63 @@ std::optional<PredicateConstraint> PredicateBase::getConstraint() const {
759783
// TODO: Make this an assertion once RenamedOp is fully accurate.
760784
return std::nullopt;
761785
}
786+
const auto &PS = *cast<PredicateSwitch>(this);
787+
unsigned NumCases = PS.CaseValues.size();
788+
assert(NumCases != 0 && "PT_Switch with no cases is invalid");
789+
// PT_Switch with >1 cases is too complex to derive a PredicateConstraint.
790+
if (NumCases > 1)
791+
return std::nullopt;
792+
// If we have a single case, we can derive a predicate constraint.
793+
return {
794+
{PS.IsDefault ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ, PS.CaseValues[0]}};
795+
}
796+
llvm_unreachable("Unknown predicate type");
797+
}
762798

763-
return {{CmpInst::ICMP_EQ, cast<PredicateSwitch>(this)->CaseValue}};
799+
std::optional<ConstantRange> PredicateBase::getRangeConstraint() const {
800+
switch (Type) {
801+
case PT_Assume:
802+
case PT_Branch: {
803+
// For PT_Assume/PT_Branch, we derive the condition constant range from
804+
// its predicate constraint.
805+
const std::optional<PredicateConstraint> &Constraint = getConstraint();
806+
if (!Constraint)
807+
return std::nullopt;
808+
CmpInst::Predicate Pred = Constraint->Predicate;
809+
Value *OtherOp = Constraint->OtherOp;
810+
const APInt *IntOp;
811+
// If the other operand is not a constant integer, we can't derive a
812+
// constant range.
813+
if (!match(OtherOp, m_APInt(IntOp)))
814+
return std::nullopt;
815+
return {ConstantRange::makeExactICmpRegion(Pred, *IntOp)};
816+
}
817+
case PT_Switch:
818+
// For PT_Switch, we directly derive the constant range from its case
819+
// values.
820+
if (Condition != RenamedOp) {
821+
// TODO: Make this an assertion once RenamedOp is fully accurate.
822+
return std::nullopt;
823+
}
824+
825+
const auto &PS = *cast<PredicateSwitch>(this);
826+
assert(!PS.CaseValues.empty() && "SwitchInfo with no cases is invalid");
827+
828+
unsigned BitWidth = PS.Condition->getType()->getScalarSizeInBits();
829+
830+
// For case values, CR = emptyset ∪ {case1, case2,..., caseN}
831+
// For default, CR = fullset ∩ ~{case1} ∩ ~{case2} ∩ ... ∩ ~{caseN}
832+
bool IsDefault = PS.IsDefault;
833+
ConstantRange CR = IsDefault ? ConstantRange::getFull(BitWidth)
834+
: ConstantRange::getEmpty(BitWidth);
835+
for (ConstantInt *Case : PS.CaseValues) {
836+
assert(Case && "CaseValue in switch should not be null");
837+
CR = IsDefault
838+
? CR.intersectWith(ConstantRange(Case->getValue()).inverse())
839+
: CR.unionWith(Case->getValue());
840+
}
841+
842+
return {CR};
764843
}
765844
llvm_unreachable("Unknown predicate type");
766845
}
@@ -818,8 +897,20 @@ class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter {
818897
PB->To->printAsOperand(OS);
819898
OS << "]";
820899
} else if (const auto *PS = dyn_cast<PredicateSwitch>(PI)) {
821-
OS << "; switch predicate info { CaseValue: " << *PS->CaseValue
822-
<< " Edge: [";
900+
OS << "; switch predicate info { ";
901+
if (PS->IsDefault) {
902+
OS << "Case: default";
903+
} else if (PS->CaseValues.size() == 1) {
904+
OS << "CaseValue: " << *PS->CaseValues[0];
905+
} else {
906+
auto CaseValues =
907+
llvm::map_range(PS->CaseValues, [](ConstantInt *Case) {
908+
return std::to_string(Case->getSExtValue());
909+
});
910+
OS << "CaseValues: " << *PS->Condition->getType() << " [ "
911+
<< join(CaseValues, ", ") << " ]";
912+
}
913+
OS << " Edge: [";
823914
PS->From->printAsOperand(OS);
824915
OS << ",";
825916
PS->To->printAsOperand(OS);

llvm/lib/Transforms/Utils/SCCPSolver.cpp

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2021,33 +2021,13 @@ void SCCPInstVisitor::handleCallArguments(CallBase &CB) {
20212021

20222022
void SCCPInstVisitor::handlePredicate(Instruction *I, Value *CopyOf,
20232023
const PredicateBase *PI) {
2024+
const std::optional<ConstantRange> &RangeConstraint =
2025+
PI->getRangeConstraint();
20242026
ValueLatticeElement CopyOfVal = getValueState(CopyOf);
2025-
const std::optional<PredicateConstraint> &Constraint = PI->getConstraint();
2026-
if (!Constraint) {
2027-
mergeInValue(ValueState[I], I, CopyOfVal);
2028-
return;
2029-
}
2030-
2031-
CmpInst::Predicate Pred = Constraint->Predicate;
2032-
Value *OtherOp = Constraint->OtherOp;
2033-
2034-
// Wait until OtherOp is resolved.
2035-
if (getValueState(OtherOp).isUnknown()) {
2036-
addAdditionalUser(OtherOp, I);
2037-
return;
2038-
}
2039-
2040-
ValueLatticeElement CondVal = getValueState(OtherOp);
2041-
ValueLatticeElement &IV = ValueState[I];
2042-
if (CondVal.isConstantRange() || CopyOfVal.isConstantRange()) {
2043-
auto ImposedCR =
2044-
ConstantRange::getFull(DL.getTypeSizeInBits(CopyOf->getType()));
2045-
2046-
// Get the range imposed by the condition.
2047-
if (CondVal.isConstantRange())
2048-
ImposedCR = ConstantRange::makeAllowedICmpRegion(
2049-
Pred, CondVal.getConstantRange());
20502027

2028+
auto MergeInValueWithImposedCR = [this, I, CopyOfVal,
2029+
CopyOf](ValueLatticeElement &IV,
2030+
ConstantRange ImposedCR) {
20512031
// Combine range info for the original value with the new range from the
20522032
// condition.
20532033
auto CopyOfCR = CopyOfVal.asConstantRange(CopyOf->getType(),
@@ -2067,18 +2047,63 @@ void SCCPInstVisitor::handlePredicate(Instruction *I, Value *CopyOf,
20672047
// unless we have conditions that are always true/false (e.g. icmp ule
20682048
// i32, %a, i32_max). For the latter overdefined/empty range will be
20692049
// inferred, but the branch will get folded accordingly anyways.
2070-
addAdditionalUser(OtherOp, I);
20712050
mergeInValue(
20722051
IV, I, ValueLatticeElement::getRange(NewCR, /*MayIncludeUndef*/ false));
2052+
};
2053+
2054+
if (RangeConstraint) {
2055+
// If we can derive a constant range directly from the predicate info,
2056+
// simply merge it into the lattice value.
2057+
// In such case, the relevant operands must be constants, and thus we do not
2058+
// need addAdditionalUser for such operands.
2059+
MergeInValueWithImposedCR(ValueState[I], *RangeConstraint);
2060+
return;
2061+
}
2062+
2063+
// If we can't simply get the constant range directly from the predicate info,
2064+
// then fallback to PredicateConstraint and let SCCPSolver resolve the
2065+
// possible Imposed CR.
2066+
2067+
const std::optional<PredicateConstraint> &Constraint = PI->getConstraint();
2068+
if (!Constraint) {
2069+
mergeInValue(ValueState[I], I, CopyOfVal);
2070+
return;
2071+
}
2072+
2073+
CmpInst::Predicate Pred = Constraint->Predicate;
2074+
Value *OtherOp = Constraint->OtherOp;
2075+
2076+
// Wait until OtherOp is resolved.
2077+
if (getValueState(OtherOp).isUnknown()) {
2078+
addAdditionalUser(OtherOp, I);
20732079
return;
2074-
} else if (Pred == CmpInst::ICMP_EQ &&
2075-
(CondVal.isConstant() || CondVal.isNotConstant())) {
2080+
}
2081+
2082+
ValueLatticeElement CondVal = getValueState(OtherOp);
2083+
ValueLatticeElement &IV = ValueState[I];
2084+
if (CondVal.isConstantRange() || CopyOfVal.isConstantRange()) {
2085+
// Get the range imposed by the condition.
2086+
auto ImposedCR =
2087+
CondVal.isConstantRange()
2088+
? ConstantRange::makeAllowedICmpRegion(Pred,
2089+
CondVal.getConstantRange())
2090+
: ConstantRange::getFull(DL.getTypeSizeInBits(CopyOf->getType()));
2091+
2092+
addAdditionalUser(OtherOp, I);
2093+
MergeInValueWithImposedCR(IV, ImposedCR);
2094+
return;
2095+
}
2096+
2097+
if (Pred == CmpInst::ICMP_EQ &&
2098+
(CondVal.isConstant() || CondVal.isNotConstant())) {
20762099
// For non-integer values or integer constant expressions, only
20772100
// propagate equal constants or not-constants.
20782101
addAdditionalUser(OtherOp, I);
20792102
mergeInValue(IV, I, CondVal);
20802103
return;
2081-
} else if (Pred == CmpInst::ICMP_NE && CondVal.isConstant()) {
2104+
}
2105+
2106+
if (Pred == CmpInst::ICMP_NE && CondVal.isConstant()) {
20822107
// Propagate inequalities.
20832108
addAdditionalUser(OtherOp, I);
20842109
mergeInValue(IV, I, ValueLatticeElement::getNot(CondVal.getConstant()));

0 commit comments

Comments
 (0)