Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 112 additions & 37 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5718,15 +5718,66 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) {
return Changed;
}

static bool casesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) {
struct ContiguousCasesResult {
ConstantInt *Min;
ConstantInt *Max;
BasicBlock *Dest;
BasicBlock *OtherDest;
SmallVectorImpl<ConstantInt *> *Cases;
SmallVectorImpl<ConstantInt *> *OtherCases;
};

static std::optional<ContiguousCasesResult>
findContiguousCases(Value *Condition, SmallVectorImpl<ConstantInt *> &Cases,
SmallVectorImpl<ConstantInt *> &OtherCases,
BasicBlock *Dest, BasicBlock *OtherDest) {
assert(Cases.size() >= 1);

array_pod_sort(Cases.begin(), Cases.end(), constantIntSortPredicate);
for (size_t I = 1, E = Cases.size(); I != E; ++I) {
if (Cases[I - 1]->getValue() != Cases[I]->getValue() + 1)
return false;
const APInt &Min = Cases.back()->getValue();
const APInt &Max = Cases.front()->getValue();
APInt Offset = Max - Min;
size_t ContiguousOffset = Cases.size() - 1;
if (Offset == ContiguousOffset) {
return ContiguousCasesResult{
/*Min=*/Cases.back(),
/*Max=*/Cases.front(),
/*Dest=*/Dest,
/*OtherDest=*/OtherDest,
/*Cases=*/&Cases,
/*OtherCases=*/&OtherCases,
};
}
return true;
ConstantRange CR = computeConstantRange(Condition, /*ForSigned=*/false);
// If this is a wrapping contiguous range, that is, [Min, OtherMin] +
// [OtherMax, Max] (also [OtherMax, OtherMin]), [OtherMin+1, OtherMax-1] is a
// contiguous range for the other destination. N.B. If CR is not a full range,
// Max+1 is not equal to Min. It's not continuous in arithmetic.
if (Max == CR.getUnsignedMax() && Min == CR.getUnsignedMin()) {
assert(Cases.size() >= 2);
auto *It =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about something as follows before the logic here?

// Find the first non-consecutive pair, and ensure this pair
// happens to be unique.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I meant this before the std::adjacent_find, which is not appearing)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is what you want. std::adjacent_find finds one non-consecutive pair and checks their distance to check if this pair is unique.

std::adjacent_find(Cases.begin(), Cases.end(), [](auto L, auto R) {
return L->getValue() != R->getValue() + 1;
});
if (It == Cases.end())
return std::nullopt;
auto [OtherMax, OtherMin] = std::make_pair(*It, *std::next(It));
if ((Max - OtherMax->getValue()) + (OtherMin->getValue() - Min) ==
Cases.size() - 2) {
return ContiguousCasesResult{
/*Min=*/cast<ConstantInt>(
ConstantInt::get(OtherMin->getType(), OtherMin->getValue() + 1)),
/*Max=*/
cast<ConstantInt>(
ConstantInt::get(OtherMax->getType(), OtherMax->getValue() - 1)),
/*Dest=*/OtherDest,
/*OtherDest=*/Dest,
/*Cases=*/&OtherCases,
/*OtherCases=*/&Cases,
};
}
}
return std::nullopt;
}

static void createUnreachableSwitchDefault(SwitchInst *Switch,
Expand Down Expand Up @@ -5763,7 +5814,6 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
bool HasDefault = !SI->defaultDestUnreachable();

auto *BB = SI->getParent();

// Partition the cases into two sets with different destinations.
BasicBlock *DestA = HasDefault ? SI->getDefaultDest() : nullptr;
BasicBlock *DestB = nullptr;
Expand Down Expand Up @@ -5797,37 +5847,62 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
assert(!CasesA.empty() || HasDefault);

// Figure out if one of the sets of cases form a contiguous range.
SmallVectorImpl<ConstantInt *> *ContiguousCases = nullptr;
BasicBlock *ContiguousDest = nullptr;
BasicBlock *OtherDest = nullptr;
if (!CasesA.empty() && casesAreContiguous(CasesA)) {
ContiguousCases = &CasesA;
ContiguousDest = DestA;
OtherDest = DestB;
} else if (casesAreContiguous(CasesB)) {
ContiguousCases = &CasesB;
ContiguousDest = DestB;
OtherDest = DestA;
} else
return false;
std::optional<ContiguousCasesResult> ContiguousCases;

// Only one icmp is needed when there is only one case.
if (!HasDefault && CasesA.size() == 1)
ContiguousCases = ContiguousCasesResult{
/*Min=*/CasesA[0],
/*Max=*/CasesA[0],
/*Dest=*/DestA,
/*OtherDest=*/DestB,
/*Cases=*/&CasesA,
/*OtherCases=*/&CasesB,
};
else if (CasesB.size() == 1)
ContiguousCases = ContiguousCasesResult{
/*Min=*/CasesB[0],
/*Max=*/CasesB[0],
/*Dest=*/DestB,
/*OtherDest=*/DestA,
/*Cases=*/&CasesB,
/*OtherCases=*/&CasesA,
};
// Correctness: Cases to the default destination cannot be contiguous cases.
else if (!HasDefault)
ContiguousCases =
findContiguousCases(SI->getCondition(), CasesA, CasesB, DestA, DestB);

// Start building the compare and branch.
if (!ContiguousCases)
ContiguousCases =
findContiguousCases(SI->getCondition(), CasesB, CasesA, DestB, DestA);

Constant *Offset = ConstantExpr::getNeg(ContiguousCases->back());
Constant *NumCases =
ConstantInt::get(Offset->getType(), ContiguousCases->size());
if (!ContiguousCases)
return false;

Value *Sub = SI->getCondition();
if (!Offset->isNullValue())
Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
auto [Min, Max, Dest, OtherDest, Cases, OtherCases] = *ContiguousCases;

Value *Cmp;
// Start building the compare and branch.

Constant *Offset = ConstantExpr::getNeg(Min);
Constant *NumCases = ConstantInt::get(Offset->getType(),
Max->getValue() - Min->getValue() + 1);
BranchInst *NewBI;
if (NumCases->isOneValue()) {
assert(Max->getValue() == Min->getValue());
Value *Cmp = Builder.CreateICmpEQ(SI->getCondition(), Min);
NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
}
// If NumCases overflowed, then all possible values jump to the successor.
if (NumCases->isNullValue() && !ContiguousCases->empty())
Cmp = ConstantInt::getTrue(SI->getContext());
else
Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
else if (NumCases->isNullValue() && !Cases->empty()) {
NewBI = Builder.CreateBr(Dest);
} else {
Value *Sub = SI->getCondition();
if (!Offset->isNullValue())
Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
}

// Update weight for the newly-created conditional branch.
if (hasBranchWeightMD(*SI)) {
Expand All @@ -5837,7 +5912,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
uint64_t TrueWeight = 0;
uint64_t FalseWeight = 0;
for (size_t I = 0, E = Weights.size(); I != E; ++I) {
if (SI->getSuccessor(I) == ContiguousDest)
if (SI->getSuccessor(I) == Dest)
TrueWeight += Weights[I];
else
FalseWeight += Weights[I];
Expand All @@ -5852,15 +5927,15 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
}

// Prune obsolete incoming values off the successors' PHI nodes.
for (auto BBI = ContiguousDest->begin(); isa<PHINode>(BBI); ++BBI) {
unsigned PreviousEdges = ContiguousCases->size();
if (ContiguousDest == SI->getDefaultDest())
for (auto BBI = Dest->begin(); isa<PHINode>(BBI); ++BBI) {
unsigned PreviousEdges = Cases->size();
if (Dest == SI->getDefaultDest())
++PreviousEdges;
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)
cast<PHINode>(BBI)->removeIncomingValue(SI->getParent());
}
for (auto BBI = OtherDest->begin(); isa<PHINode>(BBI); ++BBI) {
unsigned PreviousEdges = SI->getNumCases() - ContiguousCases->size();
unsigned PreviousEdges = OtherCases->size();
if (OtherDest == SI->getDefaultDest())
++PreviousEdges;
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ cleanup2:
; CHECK: cleanup2.corodispatch:
; CHECK: %1 = phi i8 [ 0, %handler2 ], [ 1, %catch.dispatch.2 ]
; CHECK: %2 = cleanuppad within %h1 []
; CHECK: %switch = icmp ult i8 %1, 1
; CHECK: br i1 %switch, label %cleanup2.from.handler2, label %cleanup2.from.catch.dispatch.2
; CHECK: %3 = icmp eq i8 %1, 0
; CHECK: br i1 %3, label %cleanup2.from.handler2, label %cleanup2.from.catch.dispatch.2

; CHECK: cleanup2.from.handler2:
; CHECK: %valueB.reload = load i32, ptr %valueB.spill.addr, align 4
Expand Down
12 changes: 4 additions & 8 deletions llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ declare void @foo(i32)
define void @test(i1 %a) {
; CHECK-LABEL: define void @test(
; CHECK-SAME: i1 [[A:%.*]]) {
; CHECK-NEXT: [[A_OFF:%.*]] = add i1 [[A]], true
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i1 [[A_OFF]], true
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i1 [[A]], true
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
Expand Down Expand Up @@ -209,8 +208,7 @@ define void @test5(i8 %a) {
; CHECK-SAME: i8 [[A:%.*]]) {
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[A]], 2
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], -1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], 1
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
Expand Down Expand Up @@ -243,8 +241,7 @@ define void @test6(i8 %a) {
; CHECK-NEXT: [[AND:%.*]] = and i8 [[A]], -2
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], -2
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], -1
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
Expand Down Expand Up @@ -279,8 +276,7 @@ define void @test7(i8 %a) {
; CHECK-NEXT: [[AND:%.*]] = and i8 [[A]], -2
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], -2
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], -1
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
Expand Down
Loading