Skip to content

Commit 5b4197c

Browse files
committed
[SwitchLowering] Support merging 0 and power-of-2 case.
1 parent d42f5eb commit 5b4197c

File tree

6 files changed

+211
-204
lines changed

6 files changed

+211
-204
lines changed

llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -405,13 +405,13 @@ class IRTranslator : public MachineFunctionPass {
405405
BranchProbability UnhandledProbs, SwitchCG::CaseClusterIt I,
406406
MachineBasicBlock *Fallthrough, bool FallthroughUnreachable);
407407

408-
bool lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I, Value *Cond,
409-
MachineBasicBlock *Fallthrough,
410-
bool FallthroughUnreachable,
411-
BranchProbability UnhandledProbs,
412-
MachineBasicBlock *CurMBB,
413-
MachineIRBuilder &MIB,
414-
MachineBasicBlock *SwitchMBB);
408+
bool lowerSwitchAndOrRangeWorkItem(SwitchCG::CaseClusterIt I, Value *Cond,
409+
MachineBasicBlock *Fallthrough,
410+
bool FallthroughUnreachable,
411+
BranchProbability UnhandledProbs,
412+
MachineBasicBlock *CurMBB,
413+
MachineIRBuilder &MIB,
414+
MachineBasicBlock *SwitchMBB);
415415

416416
bool lowerBitTestWorkItem(
417417
SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB,

llvm/include/llvm/CodeGen/SwitchLoweringUtils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ enum CaseClusterKind {
3535
/// A cluster of cases suitable for jump table lowering.
3636
CC_JumpTable,
3737
/// A cluster of cases suitable for bit test lowering.
38-
CC_BitTests
38+
CC_BitTests,
39+
CC_And
3940
};
4041

4142
/// A cluster of case labels.
@@ -141,6 +142,8 @@ struct CaseBlock {
141142
BranchProbability TrueProb, FalseProb;
142143
bool IsUnpredictable;
143144

145+
bool EmitAnd = false;
146+
144147
// Constructor for SelectionDAG.
145148
CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
146149
const Value *cmpmiddle, MachineBasicBlock *truebb,

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,18 +1058,15 @@ bool IRTranslator::lowerJumpTableWorkItem(SwitchCG::SwitchWorkListItem W,
10581058
}
10591059
return true;
10601060
}
1061-
bool IRTranslator::lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
1062-
Value *Cond,
1063-
MachineBasicBlock *Fallthrough,
1064-
bool FallthroughUnreachable,
1065-
BranchProbability UnhandledProbs,
1066-
MachineBasicBlock *CurMBB,
1067-
MachineIRBuilder &MIB,
1068-
MachineBasicBlock *SwitchMBB) {
1061+
bool IRTranslator::lowerSwitchAndOrRangeWorkItem(
1062+
SwitchCG::CaseClusterIt I, Value *Cond, MachineBasicBlock *Fallthrough,
1063+
bool FallthroughUnreachable, BranchProbability UnhandledProbs,
1064+
MachineBasicBlock *CurMBB, MachineIRBuilder &MIB,
1065+
MachineBasicBlock *SwitchMBB) {
10691066
using namespace SwitchCG;
10701067
const Value *RHS, *LHS, *MHS;
10711068
CmpInst::Predicate Pred;
1072-
if (I->Low == I->High) {
1069+
if (I->Low == I->High || I->Kind == CC_And) {
10731070
// Check Cond == I->Low.
10741071
Pred = CmpInst::ICMP_EQ;
10751072
LHS = Cond;
@@ -1087,6 +1084,7 @@ bool IRTranslator::lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
10871084
// The false probability is the sum of all unhandled cases.
10881085
CaseBlock CB(Pred, FallthroughUnreachable, LHS, RHS, MHS, I->MBB, Fallthrough,
10891086
CurMBB, MIB.getDebugLoc(), I->Prob, UnhandledProbs);
1087+
CB.EmitAnd = I->Kind == CC_And;
10901088

10911089
emitSwitchCase(CB, SwitchMBB, MIB);
10921090
return true;
@@ -1326,10 +1324,11 @@ bool IRTranslator::lowerSwitchWorkItem(SwitchCG::SwitchWorkListItem W,
13261324
}
13271325
break;
13281326
}
1327+
case CC_And:
13291328
case CC_Range: {
1330-
if (!lowerSwitchRangeWorkItem(I, Cond, Fallthrough,
1331-
FallthroughUnreachable, UnhandledProbs,
1332-
CurMBB, MIB, SwitchMBB)) {
1329+
if (!lowerSwitchAndOrRangeWorkItem(I, Cond, Fallthrough,
1330+
FallthroughUnreachable, UnhandledProbs,
1331+
CurMBB, MIB, SwitchMBB)) {
13331332
LLVM_DEBUG(dbgs() << "Failed to lower switch range");
13341333
return false;
13351334
}

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2887,7 +2887,17 @@ void SelectionDAGBuilder::visitSwitchCase(CaseBlock &CB,
28872887
EVT MemVT = TLI.getMemValueType(DAG.getDataLayout(), CB.CmpLHS->getType());
28882888

28892889
// Build the setcc now.
2890-
if (!CB.CmpMHS) {
2890+
if (CB.EmitAnd) {
2891+
SDLoc dl = getCurSDLoc();
2892+
2893+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2894+
EVT VT = TLI.getValueType(DAG.getDataLayout(), CB.CmpRHS->getType(), true);
2895+
SDValue C = DAG.getConstant(*cast<ConstantInt>(CB.CmpRHS), dl, VT);
2896+
SDValue Zero = DAG.getConstant(0, dl, VT);
2897+
SDValue CondLHS = getValue(CB.CmpLHS);
2898+
SDValue And = DAG.getNode(ISD::AND, dl, C.getValueType(), CondLHS, C);
2899+
Cond = DAG.getSetCC(dl, MVT::i1, And, Zero, ISD::SETEQ);
2900+
} else if (!CB.CmpMHS) {
28912901
// Fold "(X == true)" to X and "(X == false)" to !X to
28922902
// handle common cases produced by branch lowering.
28932903
if (CB.CmpRHS == ConstantInt::getTrue(*DAG.getContext()) &&
@@ -12308,10 +12318,11 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
1230812318
}
1230912319
break;
1231012320
}
12321+
case CC_And:
1231112322
case CC_Range: {
1231212323
const Value *RHS, *LHS, *MHS;
1231312324
ISD::CondCode CC;
12314-
if (I->Low == I->High) {
12325+
if (I->Low == I->High || I->Kind == CC_And) {
1231512326
// Check Cond == I->Low.
1231612327
CC = ISD::SETEQ;
1231712328
LHS = Cond;
@@ -12333,6 +12344,7 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
1233312344
CaseBlock CB(CC, LHS, RHS, MHS, I->MBB, Fallthrough, CurMBB,
1233412345
getCurSDLoc(), I->Prob, UnhandledProbs);
1233512346

12347+
CB.EmitAnd = I->Kind == CC_And;
1233612348
if (CurMBB == SwitchMBB)
1233712349
visitSwitchCase(CB, SwitchMBB);
1233812350
else

llvm/lib/CodeGen/SwitchLoweringUtils.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,41 @@ void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
362362
}
363363
}
364364
Clusters.resize(DstIndex);
365+
366+
// Check if the clusters contain one checking for 0 and another one checking
367+
// for a power-of-2 constant with matching destinations. Those clusters can be
368+
// combined to a single ane with CC_And.
369+
unsigned ZeroIdx = -1;
370+
for (const auto &[Idx, C] : enumerate(Clusters)) {
371+
if (C.Kind != CC_Range || C.Low != C.High)
372+
continue;
373+
if (C.Low->isZero()) {
374+
ZeroIdx = Idx;
375+
break;
376+
}
377+
}
378+
if (ZeroIdx == -1u)
379+
return;
380+
381+
unsigned Pow2Idx = -1;
382+
for (const auto &[Idx, C] : enumerate(Clusters)) {
383+
if (C.Kind != CC_Range || C.Low != C.High || C.MBB != Clusters[ZeroIdx].MBB)
384+
continue;
385+
if (C.Low->getValue().isPowerOf2()) {
386+
Pow2Idx = Idx;
387+
break;
388+
}
389+
}
390+
if (Pow2Idx == -1u)
391+
return;
392+
393+
APInt Pow2 = Clusters[Pow2Idx].Low->getValue();
394+
APInt NewC = (Pow2 + 1) * -1;
395+
Clusters[ZeroIdx].Low = ConstantInt::get(SI->getContext(), NewC);
396+
Clusters[ZeroIdx].High = ConstantInt::get(SI->getContext(), NewC);
397+
Clusters[ZeroIdx].Kind = CC_And;
398+
Clusters[ZeroIdx].Prob += Clusters[Pow2Idx].Prob;
399+
Clusters.erase(Clusters.begin() + Pow2Idx);
365400
}
366401

367402
bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,

0 commit comments

Comments
 (0)