@@ -5734,15 +5734,66 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) {
5734
5734
return Changed;
5735
5735
}
5736
5736
5737
- static bool casesAreContiguous (SmallVectorImpl<ConstantInt *> &Cases) {
5737
+ struct ContiguousCasesResult {
5738
+ ConstantInt *Min;
5739
+ ConstantInt *Max;
5740
+ BasicBlock *Dest;
5741
+ BasicBlock *OtherDest;
5742
+ SmallVectorImpl<ConstantInt *> *Cases;
5743
+ SmallVectorImpl<ConstantInt *> *OtherCases;
5744
+ };
5745
+
5746
+ static std::optional<ContiguousCasesResult>
5747
+ findContiguousCases (Value *Condition, SmallVectorImpl<ConstantInt *> &Cases,
5748
+ SmallVectorImpl<ConstantInt *> &OtherCases,
5749
+ BasicBlock *Dest, BasicBlock *OtherDest) {
5738
5750
assert (Cases.size () >= 1 );
5739
5751
5740
5752
array_pod_sort (Cases.begin (), Cases.end (), constantIntSortPredicate);
5741
- for (size_t I = 1 , E = Cases.size (); I != E; ++I) {
5742
- if (Cases[I - 1 ]->getValue () != Cases[I]->getValue () + 1 )
5743
- return false ;
5753
+ const APInt &Min = Cases.back ()->getValue ();
5754
+ const APInt &Max = Cases.front ()->getValue ();
5755
+ APInt Offset = Max - Min;
5756
+ size_t ContiguousOffset = Cases.size () - 1 ;
5757
+ if (Offset == ContiguousOffset) {
5758
+ return ContiguousCasesResult{
5759
+ /* Min=*/ Cases.back (),
5760
+ /* Max=*/ Cases.front (),
5761
+ /* Dest=*/ Dest,
5762
+ /* OtherDest=*/ OtherDest,
5763
+ /* Cases=*/ &Cases,
5764
+ /* OtherCases=*/ &OtherCases,
5765
+ };
5744
5766
}
5745
- return true ;
5767
+ ConstantRange CR = computeConstantRange (Condition, /* ForSigned=*/ false );
5768
+ // If this is a wrapping contiguous range, that is, [Min, OtherMin] +
5769
+ // [OtherMax, Max] (also [OtherMax, OtherMin]), [OtherMin+1, OtherMax-1] is a
5770
+ // contiguous range for the other destination. N.B. If CR is not a full range,
5771
+ // Max+1 is not equal to Min. It's not continuous in arithmetic.
5772
+ if (Max == CR.getUnsignedMax () && Min == CR.getUnsignedMin ()) {
5773
+ assert (Cases.size () >= 2 );
5774
+ auto *It =
5775
+ std::adjacent_find (Cases.begin (), Cases.end (), [](auto L, auto R) {
5776
+ return L->getValue () != R->getValue () + 1 ;
5777
+ });
5778
+ if (It == Cases.end ())
5779
+ return std::nullopt ;
5780
+ auto [OtherMax, OtherMin] = std::make_pair (*It, *std::next (It));
5781
+ if ((Max - OtherMax->getValue ()) + (OtherMin->getValue () - Min) ==
5782
+ Cases.size () - 2 ) {
5783
+ return ContiguousCasesResult{
5784
+ /* Min=*/ cast<ConstantInt>(
5785
+ ConstantInt::get (OtherMin->getType (), OtherMin->getValue () + 1 )),
5786
+ /* Max=*/
5787
+ cast<ConstantInt>(
5788
+ ConstantInt::get (OtherMax->getType (), OtherMax->getValue () - 1 )),
5789
+ /* Dest=*/ OtherDest,
5790
+ /* OtherDest=*/ Dest,
5791
+ /* Cases=*/ &OtherCases,
5792
+ /* OtherCases=*/ &Cases,
5793
+ };
5794
+ }
5795
+ }
5796
+ return std::nullopt ;
5746
5797
}
5747
5798
5748
5799
static void createUnreachableSwitchDefault (SwitchInst *Switch,
@@ -5779,7 +5830,6 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
5779
5830
bool HasDefault = !SI->defaultDestUnreachable ();
5780
5831
5781
5832
auto *BB = SI->getParent ();
5782
-
5783
5833
// Partition the cases into two sets with different destinations.
5784
5834
BasicBlock *DestA = HasDefault ? SI->getDefaultDest () : nullptr ;
5785
5835
BasicBlock *DestB = nullptr ;
@@ -5813,37 +5863,62 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
5813
5863
assert (!CasesA.empty () || HasDefault);
5814
5864
5815
5865
// Figure out if one of the sets of cases form a contiguous range.
5816
- SmallVectorImpl<ConstantInt *> *ContiguousCases = nullptr ;
5817
- BasicBlock *ContiguousDest = nullptr ;
5818
- BasicBlock *OtherDest = nullptr ;
5819
- if (!CasesA.empty () && casesAreContiguous (CasesA)) {
5820
- ContiguousCases = &CasesA;
5821
- ContiguousDest = DestA;
5822
- OtherDest = DestB;
5823
- } else if (casesAreContiguous (CasesB)) {
5824
- ContiguousCases = &CasesB;
5825
- ContiguousDest = DestB;
5826
- OtherDest = DestA;
5827
- } else
5828
- return false ;
5866
+ std::optional<ContiguousCasesResult> ContiguousCases;
5867
+
5868
+ // Only one icmp is needed when there is only one case.
5869
+ if (!HasDefault && CasesA.size () == 1 )
5870
+ ContiguousCases = ContiguousCasesResult{
5871
+ /* Min=*/ CasesA[0 ],
5872
+ /* Max=*/ CasesA[0 ],
5873
+ /* Dest=*/ DestA,
5874
+ /* OtherDest=*/ DestB,
5875
+ /* Cases=*/ &CasesA,
5876
+ /* OtherCases=*/ &CasesB,
5877
+ };
5878
+ else if (CasesB.size () == 1 )
5879
+ ContiguousCases = ContiguousCasesResult{
5880
+ /* Min=*/ CasesB[0 ],
5881
+ /* Max=*/ CasesB[0 ],
5882
+ /* Dest=*/ DestB,
5883
+ /* OtherDest=*/ DestA,
5884
+ /* Cases=*/ &CasesB,
5885
+ /* OtherCases=*/ &CasesA,
5886
+ };
5887
+ // Correctness: Cases to the default destination cannot be contiguous cases.
5888
+ else if (!HasDefault)
5889
+ ContiguousCases =
5890
+ findContiguousCases (SI->getCondition (), CasesA, CasesB, DestA, DestB);
5829
5891
5830
- // Start building the compare and branch.
5892
+ if (!ContiguousCases)
5893
+ ContiguousCases =
5894
+ findContiguousCases (SI->getCondition (), CasesB, CasesA, DestB, DestA);
5831
5895
5832
- Constant *Offset = ConstantExpr::getNeg (ContiguousCases->back ());
5833
- Constant *NumCases =
5834
- ConstantInt::get (Offset->getType (), ContiguousCases->size ());
5896
+ if (!ContiguousCases)
5897
+ return false ;
5835
5898
5836
- Value *Sub = SI->getCondition ();
5837
- if (!Offset->isNullValue ())
5838
- Sub = Builder.CreateAdd (Sub, Offset, Sub->getName () + " .off" );
5899
+ auto [Min, Max, Dest, OtherDest, Cases, OtherCases] = *ContiguousCases;
5839
5900
5840
- Value *Cmp;
5901
+ // Start building the compare and branch.
5902
+
5903
+ Constant *Offset = ConstantExpr::getNeg (Min);
5904
+ Constant *NumCases = ConstantInt::get (Offset->getType (),
5905
+ Max->getValue () - Min->getValue () + 1 );
5906
+ BranchInst *NewBI;
5907
+ if (NumCases->isOneValue ()) {
5908
+ assert (Max->getValue () == Min->getValue ());
5909
+ Value *Cmp = Builder.CreateICmpEQ (SI->getCondition (), Min);
5910
+ NewBI = Builder.CreateCondBr (Cmp, Dest, OtherDest);
5911
+ }
5841
5912
// If NumCases overflowed, then all possible values jump to the successor.
5842
- if (NumCases->isNullValue () && !ContiguousCases->empty ())
5843
- Cmp = ConstantInt::getTrue (SI->getContext ());
5844
- else
5845
- Cmp = Builder.CreateICmpULT (Sub, NumCases, " switch" );
5846
- BranchInst *NewBI = Builder.CreateCondBr (Cmp, ContiguousDest, OtherDest);
5913
+ else if (NumCases->isNullValue () && !Cases->empty ()) {
5914
+ NewBI = Builder.CreateBr (Dest);
5915
+ } else {
5916
+ Value *Sub = SI->getCondition ();
5917
+ if (!Offset->isNullValue ())
5918
+ Sub = Builder.CreateAdd (Sub, Offset, Sub->getName () + " .off" );
5919
+ Value *Cmp = Builder.CreateICmpULT (Sub, NumCases, " switch" );
5920
+ NewBI = Builder.CreateCondBr (Cmp, Dest, OtherDest);
5921
+ }
5847
5922
5848
5923
// Update weight for the newly-created conditional branch.
5849
5924
if (hasBranchWeightMD (*SI)) {
@@ -5853,7 +5928,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
5853
5928
uint64_t TrueWeight = 0 ;
5854
5929
uint64_t FalseWeight = 0 ;
5855
5930
for (size_t I = 0 , E = Weights.size (); I != E; ++I) {
5856
- if (SI->getSuccessor (I) == ContiguousDest )
5931
+ if (SI->getSuccessor (I) == Dest )
5857
5932
TrueWeight += Weights[I];
5858
5933
else
5859
5934
FalseWeight += Weights[I];
@@ -5868,15 +5943,15 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
5868
5943
}
5869
5944
5870
5945
// Prune obsolete incoming values off the successors' PHI nodes.
5871
- for (auto BBI = ContiguousDest ->begin (); isa<PHINode>(BBI); ++BBI) {
5872
- unsigned PreviousEdges = ContiguousCases ->size ();
5873
- if (ContiguousDest == SI->getDefaultDest ())
5946
+ for (auto BBI = Dest ->begin (); isa<PHINode>(BBI); ++BBI) {
5947
+ unsigned PreviousEdges = Cases ->size ();
5948
+ if (Dest == SI->getDefaultDest ())
5874
5949
++PreviousEdges;
5875
5950
for (unsigned I = 0 , E = PreviousEdges - 1 ; I != E; ++I)
5876
5951
cast<PHINode>(BBI)->removeIncomingValue (SI->getParent ());
5877
5952
}
5878
5953
for (auto BBI = OtherDest->begin (); isa<PHINode>(BBI); ++BBI) {
5879
- unsigned PreviousEdges = SI-> getNumCases () - ContiguousCases ->size ();
5954
+ unsigned PreviousEdges = OtherCases ->size ();
5880
5955
if (OtherDest == SI->getDefaultDest ())
5881
5956
++PreviousEdges;
5882
5957
for (unsigned I = 0 , E = PreviousEdges - 1 ; I != E; ++I)
0 commit comments