@@ -7131,6 +7131,119 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
71317131 return true ;
71327132}
71337133
7134+ // / Fold switch over ucmp/scmp intrinsic to br if two of the switch arms have
7135+ // / the same destination.
7136+ static bool simplifySwitchOfCmpIntrinsic (SwitchInst *SI, IRBuilderBase &Builder,
7137+ DomTreeUpdater *DTU) {
7138+ auto *Cmp = dyn_cast<CmpIntrinsic>(SI->getCondition ());
7139+ if (!Cmp || !Cmp->hasOneUse ())
7140+ return false ;
7141+
7142+ SmallVector<uint32_t , 4 > Weights;
7143+ bool HasWeights = extractBranchWeights (getBranchWeightMDNode (*SI), Weights);
7144+ if (!HasWeights)
7145+ Weights.resize (4 ); // Avoid checking HasWeights everywhere.
7146+
7147+ // Normalize to [us]cmp == Res ? Succ : OtherSucc.
7148+ int64_t Res;
7149+ BasicBlock *Succ, *OtherSucc;
7150+ uint32_t SuccWeight = 0 , OtherSuccWeight = 0 ;
7151+ BasicBlock *Unreachable = nullptr ;
7152+
7153+ if (SI->getNumCases () == 2 ) {
7154+ // Find which of 1, 0 or -1 is missing (handled by default dest).
7155+ SmallSet<int64_t , 3 > Missing;
7156+ Missing.insert (1 );
7157+ Missing.insert (0 );
7158+ Missing.insert (-1 );
7159+
7160+ Succ = SI->getDefaultDest ();
7161+ SuccWeight = Weights[0 ];
7162+ OtherSucc = nullptr ;
7163+ for (auto &Case : SI->cases ()) {
7164+ std::optional<int64_t > Val =
7165+ Case.getCaseValue ()->getValue ().trySExtValue ();
7166+ if (!Val)
7167+ return false ;
7168+ if (!Missing.erase (*Val))
7169+ return false ;
7170+ if (OtherSucc && OtherSucc != Case.getCaseSuccessor ())
7171+ return false ;
7172+ OtherSucc = Case.getCaseSuccessor ();
7173+ OtherSuccWeight += Weights[Case.getSuccessorIndex ()];
7174+ }
7175+
7176+ assert (Missing.size () == 1 && " Should have one case left" );
7177+ Res = *Missing.begin ();
7178+ } else if (SI->getNumCases () == 3 && SI->defaultDestUndefined ()) {
7179+ // Normalize so that Succ is taken once and OtherSucc twice.
7180+ Unreachable = SI->getDefaultDest ();
7181+ Succ = OtherSucc = nullptr ;
7182+ for (auto &Case : SI->cases ()) {
7183+ BasicBlock *NewSucc = Case.getCaseSuccessor ();
7184+ uint32_t Weight = Weights[Case.getSuccessorIndex ()];
7185+ if (!OtherSucc || OtherSucc == NewSucc) {
7186+ OtherSucc = NewSucc;
7187+ OtherSuccWeight += Weight;
7188+ } else if (!Succ) {
7189+ Succ = NewSucc;
7190+ SuccWeight = Weight;
7191+ } else if (Succ == NewSucc) {
7192+ std::swap (Succ, OtherSucc);
7193+ std::swap (SuccWeight, OtherSuccWeight);
7194+ } else
7195+ return false ;
7196+ }
7197+ for (auto &Case : SI->cases ()) {
7198+ std::optional<int64_t > Val =
7199+ Case.getCaseValue ()->getValue ().trySExtValue ();
7200+ if (!Val || (Val != 1 && Val != 0 && Val != -1 ))
7201+ return false ;
7202+ if (Case.getCaseSuccessor () == Succ) {
7203+ Res = *Val;
7204+ break ;
7205+ }
7206+ }
7207+ } else {
7208+ return false ;
7209+ }
7210+
7211+ // Determine predicate for the missing case.
7212+ ICmpInst::Predicate Pred;
7213+ switch (Res) {
7214+ case 1 :
7215+ Pred = ICmpInst::ICMP_UGT;
7216+ break ;
7217+ case 0 :
7218+ Pred = ICmpInst::ICMP_EQ;
7219+ break ;
7220+ case -1 :
7221+ Pred = ICmpInst::ICMP_ULT;
7222+ break ;
7223+ }
7224+ if (Cmp->isSigned ())
7225+ Pred = ICmpInst::getSignedPredicate (Pred);
7226+
7227+ MDNode *NewWeights = nullptr ;
7228+ if (HasWeights)
7229+ NewWeights = MDBuilder (SI->getContext ())
7230+ .createBranchWeights (SuccWeight, OtherSuccWeight);
7231+
7232+ BasicBlock *BB = SI->getParent ();
7233+ Builder.SetInsertPoint (SI->getIterator ());
7234+ Value *ICmp = Builder.CreateICmp (Pred, Cmp->getLHS (), Cmp->getRHS ());
7235+ Builder.CreateCondBr (ICmp, Succ, OtherSucc, NewWeights,
7236+ SI->getMetadata (LLVMContext::MD_unpredictable));
7237+ OtherSucc->removePredecessor (BB);
7238+ if (Unreachable)
7239+ Unreachable->removePredecessor (BB);
7240+ SI->eraseFromParent ();
7241+ Cmp->eraseFromParent ();
7242+ if (DTU && Unreachable)
7243+ DTU->applyUpdates ({{DominatorTree::Delete, BB, Unreachable}});
7244+ return true ;
7245+ }
7246+
71347247bool SimplifyCFGOpt::simplifySwitch (SwitchInst *SI, IRBuilder<> &Builder) {
71357248 BasicBlock *BB = SI->getParent ();
71367249
@@ -7163,6 +7276,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
71637276 if (eliminateDeadSwitchCases (SI, DTU, Options.AC , DL))
71647277 return requestResimplify ();
71657278
7279+ if (simplifySwitchOfCmpIntrinsic (SI, Builder, DTU))
7280+ return requestResimplify ();
7281+
71667282 if (trySwitchToSelect (SI, Builder, DTU, DL, TTI))
71677283 return requestResimplify ();
71687284
0 commit comments