Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 89 additions & 24 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5718,15 +5718,49 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) {
return Changed;
}

static bool casesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) {
static bool casesAreContiguous(Value *Condition,
SmallVectorImpl<ConstantInt *> &Cases,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the function name would deserve an update, e.g., findContiguousCases (and maybe return a std::optional struct instead of the three values as out parameters)?

ConstantInt *&ContiguousCasesMin,
ConstantInt *&ContiguousCasesMax,
bool &IsWrapping) {
assert(Cases.size() >= 1);

array_pod_sort(Cases.begin(), Cases.end(), constantIntSortPredicate);
for (size_t I = 1, E = Cases.size(); I != E; ++I) {
if (Cases[I - 1]->getValue() != Cases[I]->getValue() + 1)
auto Min = Cases.back()->getValue();
auto Max = Cases.front()->getValue();
auto Offset = Max - Min;
auto ContiguousOffset = Cases.size() - 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for auto here, having the type explicit (const APInt&, unsigned) would make the code more readable.

if (Offset == ContiguousOffset) {
ContiguousCasesMin = Cases.back();
ContiguousCasesMax = Cases.front();
return true;
}
ConstantRange CR = computeConstantRange(Condition, /*ForSigned=*/false);
// If this is a wrapping contiguous range, that is, [Min, OtherMin] +
// [OtherMax, Max] (also [OtherMax, OtherMin]), [OtherMin+1, OtherMax-1] is a
// contiguous range for the other destination. N.B. If CR is not a full range,
// Max+1 is not equal to Min. It's not continuous in arithmetic.
if (Max == CR.getUnsignedMax() && Min == CR.getUnsignedMin()) {
assert(Cases.size() >= 2);
auto *It =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about something as follows before the logic here?

// Find the first non-consecutive pair, and ensure this pair
// happens to be unique.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I meant this before the std::adjacent_find, which is not appearing)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is what you want. std::adjacent_find finds one non-consecutive pair and checks their distance to check if this pair is unique.

std::adjacent_find(Cases.begin(), Cases.end(), [](auto L, auto R) {
return L->getValue() != R->getValue() + 1;
});
if (It == Cases.end())
return false;
auto *OtherMax = *It;
auto *OtherMin = *(It + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto *OtherMax = *It;
auto *OtherMin = *(It + 1);
auto [OtherMax, OtherMin] = std::make_pair(*It, *std::next(It));

if ((Max - OtherMax->getValue()) + (OtherMin->getValue() - Min) ==
Cases.size() - 2) {
ContiguousCasesMin = cast<ConstantInt>(
ConstantInt::get(OtherMin->getType(), OtherMin->getValue() + 1));
ContiguousCasesMax = cast<ConstantInt>(
ConstantInt::get(OtherMax->getType(), OtherMax->getValue() - 1));
IsWrapping = true;
return true;
}
}
return true;
return false;
}

static void createUnreachableSwitchDefault(SwitchInst *Switch,
Expand Down Expand Up @@ -5797,37 +5831,68 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
assert(!CasesA.empty() || HasDefault);

// Figure out if one of the sets of cases form a contiguous range.
SmallVectorImpl<ConstantInt *> *ContiguousCases = nullptr;
ConstantInt *ContiguousCasesMin = nullptr;
ConstantInt *ContiguousCasesMax = nullptr;
BasicBlock *ContiguousDest = nullptr;
BasicBlock *OtherDest = nullptr;
if (!CasesA.empty() && casesAreContiguous(CasesA)) {
ContiguousCases = &CasesA;
bool IsWrapping = false;
SmallVectorImpl<ConstantInt *> *ContiguousCases = &CasesA;
SmallVectorImpl<ConstantInt *> *OtherCases = &CasesB;

// Only one icmp is needed when there is only one case.
if (!HasDefault && CasesA.size() == 1) {
ContiguousCasesMax = CasesA[0];
ContiguousCasesMin = CasesA[0];
ContiguousDest = DestA;
OtherDest = DestB;
} else if (casesAreContiguous(CasesB)) {
ContiguousCases = &CasesB;
} else if (CasesB.size() == 1) {
ContiguousCasesMax = CasesB[0];
ContiguousCasesMin = CasesB[0];
ContiguousDest = DestB;
OtherDest = DestA;
std::swap(ContiguousCases, OtherCases);
}
// Correctness: Cases to the default destination cannot be contiguous cases.
else if (!HasDefault && !CasesA.empty() &&
casesAreContiguous(SI->getCondition(), CasesA, ContiguousCasesMin,
ContiguousCasesMax, IsWrapping)) {
ContiguousDest = DestA;
OtherDest = DestB;
} else if (casesAreContiguous(SI->getCondition(), CasesB, ContiguousCasesMin,
ContiguousCasesMax, IsWrapping)) {
ContiguousDest = DestB;
OtherDest = DestA;
std::swap(ContiguousCases, OtherCases);
} else
return false;

// Start building the compare and branch.

Constant *Offset = ConstantExpr::getNeg(ContiguousCases->back());
Constant *NumCases =
ConstantInt::get(Offset->getType(), ContiguousCases->size());
if (IsWrapping) {
std::swap(ContiguousDest, OtherDest);
std::swap(ContiguousCases, OtherCases);
}

Value *Sub = SI->getCondition();
if (!Offset->isNullValue())
Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
// Start building the compare and branch.

Value *Cmp;
Constant *Offset = ConstantExpr::getNeg(ContiguousCasesMin);
Constant *NumCases = ConstantInt::get(Offset->getType(),
ContiguousCasesMax->getValue() -
ContiguousCasesMin->getValue() + 1);
BranchInst *NewBI;
if (NumCases->isOneValue()) {
assert(ContiguousCasesMax->getValue() == ContiguousCasesMin->getValue());
Value *Cmp = Builder.CreateICmpEQ(SI->getCondition(), ContiguousCasesMin);
NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
}
// If NumCases overflowed, then all possible values jump to the successor.
if (NumCases->isNullValue() && !ContiguousCases->empty())
Cmp = ConstantInt::getTrue(SI->getContext());
else
Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
else if (NumCases->isNullValue() && !ContiguousCases->empty()) {
NewBI = Builder.CreateBr(ContiguousDest);
} else {
Value *Sub = SI->getCondition();
if (!Offset->isNullValue())
Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
}

// Update weight for the newly-created conditional branch.
if (hasBranchWeightMD(*SI)) {
Expand Down Expand Up @@ -5860,7 +5925,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
cast<PHINode>(BBI)->removeIncomingValue(SI->getParent());
}
for (auto BBI = OtherDest->begin(); isa<PHINode>(BBI); ++BBI) {
unsigned PreviousEdges = SI->getNumCases() - ContiguousCases->size();
unsigned PreviousEdges = OtherCases->size();
if (OtherDest == SI->getDefaultDest())
++PreviousEdges;
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ cleanup2:
; CHECK: cleanup2.corodispatch:
; CHECK: %1 = phi i8 [ 0, %handler2 ], [ 1, %catch.dispatch.2 ]
; CHECK: %2 = cleanuppad within %h1 []
; CHECK: %switch = icmp ult i8 %1, 1
; CHECK: br i1 %switch, label %cleanup2.from.handler2, label %cleanup2.from.catch.dispatch.2
; CHECK: %3 = icmp eq i8 %1, 0
; CHECK: br i1 %3, label %cleanup2.from.handler2, label %cleanup2.from.catch.dispatch.2

; CHECK: cleanup2.from.handler2:
; CHECK: %valueB.reload = load i32, ptr %valueB.spill.addr, align 4
Expand Down
12 changes: 4 additions & 8 deletions llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ declare void @foo(i32)
define void @test(i1 %a) {
; CHECK-LABEL: define void @test(
; CHECK-SAME: i1 [[A:%.*]]) {
; CHECK-NEXT: [[A_OFF:%.*]] = add i1 [[A]], true
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i1 [[A_OFF]], true
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i1 [[A]], true
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
Expand Down Expand Up @@ -209,8 +208,7 @@ define void @test5(i8 %a) {
; CHECK-SAME: i8 [[A:%.*]]) {
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[A]], 2
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], -1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], 1
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
Expand Down Expand Up @@ -243,8 +241,7 @@ define void @test6(i8 %a) {
; CHECK-NEXT: [[AND:%.*]] = and i8 [[A]], -2
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], -2
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], -1
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
Expand Down Expand Up @@ -279,8 +276,7 @@ define void @test7(i8 %a) {
; CHECK-NEXT: [[AND:%.*]] = and i8 [[A]], -2
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], -2
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], -1
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
Expand Down
Loading