1414#include " llvm/ADT/DenseMap.h"
1515#include " llvm/ADT/STLExtras.h"
1616#include " llvm/ADT/SmallPtrSet.h"
17+ #include " llvm/ADT/StringExtras.h"
1718#include " llvm/Analysis/AssumptionCache.h"
1819#include " llvm/IR/AssemblyAnnotationWriter.h"
20+ #include " llvm/IR/Constants.h"
1921#include " llvm/IR/Dominators.h"
2022#include " llvm/IR/IRBuilder.h"
2123#include " llvm/IR/InstIterator.h"
@@ -442,6 +444,7 @@ void PredicateInfoBuilder::processBranch(
442444 }
443445 }
444446}
447+
445448// Process a block terminating switch, and place relevant operations to be
446449// renamed into OpsToRename.
447450void PredicateInfoBuilder::processSwitch (
@@ -450,21 +453,41 @@ void PredicateInfoBuilder::processSwitch(
450453 Value *Op = SI->getCondition ();
451454 if ((!isa<Instruction>(Op) && !isa<Argument>(Op)) || Op->hasOneUse ())
452455 return ;
456+ using CaseValuesVec = PredicateSwitch::CaseValuesVec;
457+
458+ BasicBlock *DefaultDest = SI->getDefaultDest ();
459+ // Remember all cases for PT_Switch related to the default dest.
460+ CaseValuesVec AllCases;
461+ AllCases.reserve (SI->getNumCases ());
453462
454- // Remember how many outgoing edges there are to every successor.
455- SmallDenseMap<BasicBlock *, unsigned , 16 > SwitchEdges;
456- for (BasicBlock *TargetBlock : successors (BranchBB))
457- ++SwitchEdges[TargetBlock];
463+ // For each successor, remember all its related case values.
464+ SmallDenseMap<BasicBlock *, CaseValuesVec, 16 > SwitchEdges;
458465
459- // Now propagate info for each case value
460466 for (auto C : SI->cases ()) {
461467 BasicBlock *TargetBlock = C.getCaseSuccessor ();
462- if (SwitchEdges.lookup (TargetBlock) == 1 ) {
463- PredicateSwitch *PS = new (Allocator) PredicateSwitch (
464- Op, SI->getParent (), TargetBlock, C.getCaseValue (), SI);
465- addInfoFor (OpsToRename, Op, PS);
466- }
468+ // / TODO: Replace this if with an assertion if we can guarantee that
469+ // / this function must be called after SimplifyCFG, as a canonical switch
470+ // / should not have case dest being the default dest.
471+ if (TargetBlock == DefaultDest)
472+ continue ;
473+ // Only collect real case values
474+ ConstantInt *CaseValue = C.getCaseValue ();
475+ AllCases.push_back (CaseValue);
476+ SwitchEdges[TargetBlock].push_back (CaseValue);
467477 }
478+
479+ // Now propagate info for each case successor
480+ for (auto *CaseSucc : SwitchEdges.keys ()) {
481+ auto &CaseValues = SwitchEdges.at (CaseSucc);
482+ PredicateSwitch *PS = new (Allocator) PredicateSwitch (
483+ Op, SI->getParent (), CaseSucc, std::move (CaseValues), SI, false );
484+ addInfoFor (OpsToRename, Op, PS);
485+ }
486+
487+ // Finally, propagate info for the default case
488+ PredicateSwitch *PS = new (Allocator) PredicateSwitch (
489+ Op, SI->getParent (), DefaultDest, std::move (AllCases), SI, true );
490+ addInfoFor (OpsToRename, Op, PS);
468491}
469492
470493// Build predicate info for our function
@@ -500,8 +523,8 @@ void PredicateInfoBuilder::buildPredicateInfo() {
500523// Given the renaming stack, make all the operands currently on the stack real
501524// by inserting them into the IR. Return the last operation's value.
502525Value *PredicateInfoBuilder::materializeStack (unsigned int &Counter,
503- ValueDFSStack &RenameStack,
504- Value *OrigOp) {
526+ ValueDFSStack &RenameStack,
527+ Value *OrigOp) {
505528 // Find the first thing we have to materialize
506529 auto RevIter = RenameStack.rbegin ();
507530 for (; RevIter != RenameStack.rend (); ++RevIter)
@@ -601,7 +624,8 @@ void PredicateInfoBuilder::renameUses(SmallVectorImpl<Value *> &OpsToRename) {
601624 // block, and handle it specially. We know that it goes last, and only
602625 // dominate phi uses.
603626 auto BlockEdge = getBlockEdge (PossibleCopy);
604- if (!BlockEdge.second ->getSinglePredecessor ()) {
627+ // We use unique predecessor to identify the mult-cases dest in switch
628+ if (!BlockEdge.second ->getUniquePredecessor ()) {
605629 VD.LocalNum = LN_Last;
606630 auto *DomNode = DT.getNode (BlockEdge.first );
607631 if (DomNode) {
@@ -759,8 +783,63 @@ std::optional<PredicateConstraint> PredicateBase::getConstraint() const {
759783 // TODO: Make this an assertion once RenamedOp is fully accurate.
760784 return std::nullopt ;
761785 }
786+ const auto &PS = *cast<PredicateSwitch>(this );
787+ unsigned NumCases = PS.CaseValues .size ();
788+ assert (NumCases != 0 && " PT_Switch with no cases is invalid" );
789+ // PT_Switch with >1 cases is too complex to derive a PredicateConstraint.
790+ if (NumCases > 1 )
791+ return std::nullopt ;
792+ // If we have a single case, we can derive a predicate constraint.
793+ return {
794+ {PS.IsDefault ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ, PS.CaseValues [0 ]}};
795+ }
796+ llvm_unreachable (" Unknown predicate type" );
797+ }
762798
763- return {{CmpInst::ICMP_EQ, cast<PredicateSwitch>(this )->CaseValue }};
799+ std::optional<ConstantRange> PredicateBase::getRangeConstraint () const {
800+ switch (Type) {
801+ case PT_Assume:
802+ case PT_Branch: {
803+ // For PT_Assume/PT_Branch, we derive the condition constant range from
804+ // its predicate constraint.
805+ const std::optional<PredicateConstraint> &Constraint = getConstraint ();
806+ if (!Constraint)
807+ return std::nullopt ;
808+ CmpInst::Predicate Pred = Constraint->Predicate ;
809+ Value *OtherOp = Constraint->OtherOp ;
810+ const APInt *IntOp;
811+ // If the other operand is not a constant integer, we can't derive a
812+ // constant range.
813+ if (!match (OtherOp, m_APInt (IntOp)))
814+ return std::nullopt ;
815+ return {ConstantRange::makeExactICmpRegion (Pred, *IntOp)};
816+ }
817+ case PT_Switch:
818+ // For PT_Switch, we directly derive the constant range from its case
819+ // values.
820+ if (Condition != RenamedOp) {
821+ // TODO: Make this an assertion once RenamedOp is fully accurate.
822+ return std::nullopt ;
823+ }
824+
825+ const auto &PS = *cast<PredicateSwitch>(this );
826+ assert (!PS.CaseValues .empty () && " SwitchInfo with no cases is invalid" );
827+
828+ unsigned BitWidth = PS.Condition ->getType ()->getScalarSizeInBits ();
829+
830+ // For case values, CR = emptyset ∪ {case1, case2,..., caseN}
831+ // For default, CR = fullset ∩ ~{case1} ∩ ~{case2} ∩ ... ∩ ~{caseN}
832+ bool IsDefault = PS.IsDefault ;
833+ ConstantRange CR = IsDefault ? ConstantRange::getFull (BitWidth)
834+ : ConstantRange::getEmpty (BitWidth);
835+ for (ConstantInt *Case : PS.CaseValues ) {
836+ assert (Case && " CaseValue in switch should not be null" );
837+ CR = IsDefault
838+ ? CR.intersectWith (ConstantRange (Case->getValue ()).inverse ())
839+ : CR.unionWith (Case->getValue ());
840+ }
841+
842+ return {CR};
764843 }
765844 llvm_unreachable (" Unknown predicate type" );
766845}
@@ -818,8 +897,20 @@ class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter {
818897 PB->To ->printAsOperand (OS);
819898 OS << " ]" ;
820899 } else if (const auto *PS = dyn_cast<PredicateSwitch>(PI)) {
821- OS << " ; switch predicate info { CaseValue: " << *PS->CaseValue
822- << " Edge: [" ;
900+ OS << " ; switch predicate info { " ;
901+ if (PS->IsDefault ) {
902+ OS << " Case: default" ;
903+ } else if (PS->CaseValues .size () == 1 ) {
904+ OS << " CaseValue: " << *PS->CaseValues [0 ];
905+ } else {
906+ auto CaseValues =
907+ llvm::map_range (PS->CaseValues , [](ConstantInt *Case) {
908+ return std::to_string (Case->getSExtValue ());
909+ });
910+ OS << " CaseValues: " << *PS->Condition ->getType () << " [ "
911+ << join (CaseValues, " , " ) << " ]" ;
912+ }
913+ OS << " Edge: [" ;
823914 PS->From ->printAsOperand (OS);
824915 OS << " ," ;
825916 PS->To ->printAsOperand (OS);
0 commit comments