diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 8bba634521e3e..debfe914ec803 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -5718,15 +5718,66 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { return Changed; } -static bool casesAreContiguous(SmallVectorImpl &Cases) { +struct ContiguousCasesResult { + ConstantInt *Min; + ConstantInt *Max; + BasicBlock *Dest; + BasicBlock *OtherDest; + SmallVectorImpl *Cases; + SmallVectorImpl *OtherCases; +}; + +static std::optional +findContiguousCases(Value *Condition, SmallVectorImpl &Cases, + SmallVectorImpl &OtherCases, + BasicBlock *Dest, BasicBlock *OtherDest) { assert(Cases.size() >= 1); array_pod_sort(Cases.begin(), Cases.end(), constantIntSortPredicate); - for (size_t I = 1, E = Cases.size(); I != E; ++I) { - if (Cases[I - 1]->getValue() != Cases[I]->getValue() + 1) - return false; + const APInt &Min = Cases.back()->getValue(); + const APInt &Max = Cases.front()->getValue(); + APInt Offset = Max - Min; + size_t ContiguousOffset = Cases.size() - 1; + if (Offset == ContiguousOffset) { + return ContiguousCasesResult{ + /*Min=*/Cases.back(), + /*Max=*/Cases.front(), + /*Dest=*/Dest, + /*OtherDest=*/OtherDest, + /*Cases=*/&Cases, + /*OtherCases=*/&OtherCases, + }; } - return true; + ConstantRange CR = computeConstantRange(Condition, /*ForSigned=*/false); + // If this is a wrapping contiguous range, that is, [Min, OtherMin] + + // [OtherMax, Max] (also [OtherMax, OtherMin]), [OtherMin+1, OtherMax-1] is a + // contiguous range for the other destination. N.B. If CR is not a full range, + // Max+1 is not equal to Min. It's not continuous in arithmetic. + if (Max == CR.getUnsignedMax() && Min == CR.getUnsignedMin()) { + assert(Cases.size() >= 2); + auto *It = + std::adjacent_find(Cases.begin(), Cases.end(), [](auto L, auto R) { + return L->getValue() != R->getValue() + 1; + }); + if (It == Cases.end()) + return std::nullopt; + auto [OtherMax, OtherMin] = std::make_pair(*It, *std::next(It)); + if ((Max - OtherMax->getValue()) + (OtherMin->getValue() - Min) == + Cases.size() - 2) { + return ContiguousCasesResult{ + /*Min=*/cast( + ConstantInt::get(OtherMin->getType(), OtherMin->getValue() + 1)), + /*Max=*/ + cast( + ConstantInt::get(OtherMax->getType(), OtherMax->getValue() - 1)), + /*Dest=*/OtherDest, + /*OtherDest=*/Dest, + /*Cases=*/&OtherCases, + /*OtherCases=*/&Cases, + }; + } + } + return std::nullopt; } static void createUnreachableSwitchDefault(SwitchInst *Switch, @@ -5763,7 +5814,6 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, bool HasDefault = !SI->defaultDestUnreachable(); auto *BB = SI->getParent(); - // Partition the cases into two sets with different destinations. BasicBlock *DestA = HasDefault ? SI->getDefaultDest() : nullptr; BasicBlock *DestB = nullptr; @@ -5797,37 +5847,62 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, assert(!CasesA.empty() || HasDefault); // Figure out if one of the sets of cases form a contiguous range. - SmallVectorImpl *ContiguousCases = nullptr; - BasicBlock *ContiguousDest = nullptr; - BasicBlock *OtherDest = nullptr; - if (!CasesA.empty() && casesAreContiguous(CasesA)) { - ContiguousCases = &CasesA; - ContiguousDest = DestA; - OtherDest = DestB; - } else if (casesAreContiguous(CasesB)) { - ContiguousCases = &CasesB; - ContiguousDest = DestB; - OtherDest = DestA; - } else - return false; + std::optional ContiguousCases; + + // Only one icmp is needed when there is only one case. + if (!HasDefault && CasesA.size() == 1) + ContiguousCases = ContiguousCasesResult{ + /*Min=*/CasesA[0], + /*Max=*/CasesA[0], + /*Dest=*/DestA, + /*OtherDest=*/DestB, + /*Cases=*/&CasesA, + /*OtherCases=*/&CasesB, + }; + else if (CasesB.size() == 1) + ContiguousCases = ContiguousCasesResult{ + /*Min=*/CasesB[0], + /*Max=*/CasesB[0], + /*Dest=*/DestB, + /*OtherDest=*/DestA, + /*Cases=*/&CasesB, + /*OtherCases=*/&CasesA, + }; + // Correctness: Cases to the default destination cannot be contiguous cases. + else if (!HasDefault) + ContiguousCases = + findContiguousCases(SI->getCondition(), CasesA, CasesB, DestA, DestB); - // Start building the compare and branch. + if (!ContiguousCases) + ContiguousCases = + findContiguousCases(SI->getCondition(), CasesB, CasesA, DestB, DestA); - Constant *Offset = ConstantExpr::getNeg(ContiguousCases->back()); - Constant *NumCases = - ConstantInt::get(Offset->getType(), ContiguousCases->size()); + if (!ContiguousCases) + return false; - Value *Sub = SI->getCondition(); - if (!Offset->isNullValue()) - Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off"); + auto [Min, Max, Dest, OtherDest, Cases, OtherCases] = *ContiguousCases; - Value *Cmp; + // Start building the compare and branch. + + Constant *Offset = ConstantExpr::getNeg(Min); + Constant *NumCases = ConstantInt::get(Offset->getType(), + Max->getValue() - Min->getValue() + 1); + BranchInst *NewBI; + if (NumCases->isOneValue()) { + assert(Max->getValue() == Min->getValue()); + Value *Cmp = Builder.CreateICmpEQ(SI->getCondition(), Min); + NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest); + } // If NumCases overflowed, then all possible values jump to the successor. - if (NumCases->isNullValue() && !ContiguousCases->empty()) - Cmp = ConstantInt::getTrue(SI->getContext()); - else - Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch"); - BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest); + else if (NumCases->isNullValue() && !Cases->empty()) { + NewBI = Builder.CreateBr(Dest); + } else { + Value *Sub = SI->getCondition(); + if (!Offset->isNullValue()) + Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off"); + Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch"); + NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest); + } // Update weight for the newly-created conditional branch. if (hasBranchWeightMD(*SI)) { @@ -5837,7 +5912,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, uint64_t TrueWeight = 0; uint64_t FalseWeight = 0; for (size_t I = 0, E = Weights.size(); I != E; ++I) { - if (SI->getSuccessor(I) == ContiguousDest) + if (SI->getSuccessor(I) == Dest) TrueWeight += Weights[I]; else FalseWeight += Weights[I]; @@ -5852,15 +5927,15 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, } // Prune obsolete incoming values off the successors' PHI nodes. - for (auto BBI = ContiguousDest->begin(); isa(BBI); ++BBI) { - unsigned PreviousEdges = ContiguousCases->size(); - if (ContiguousDest == SI->getDefaultDest()) + for (auto BBI = Dest->begin(); isa(BBI); ++BBI) { + unsigned PreviousEdges = Cases->size(); + if (Dest == SI->getDefaultDest()) ++PreviousEdges; for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I) cast(BBI)->removeIncomingValue(SI->getParent()); } for (auto BBI = OtherDest->begin(); isa(BBI); ++BBI) { - unsigned PreviousEdges = SI->getNumCases() - ContiguousCases->size(); + unsigned PreviousEdges = OtherCases->size(); if (OtherDest == SI->getDefaultDest()) ++PreviousEdges; for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I) diff --git a/llvm/test/Transforms/Coroutines/coro-catchswitch-cleanuppad.ll b/llvm/test/Transforms/Coroutines/coro-catchswitch-cleanuppad.ll index d0e7c1c29eb32..e1e1611ee3362 100644 --- a/llvm/test/Transforms/Coroutines/coro-catchswitch-cleanuppad.ll +++ b/llvm/test/Transforms/Coroutines/coro-catchswitch-cleanuppad.ll @@ -80,8 +80,8 @@ cleanup2: ; CHECK: cleanup2.corodispatch: ; CHECK: %1 = phi i8 [ 0, %handler2 ], [ 1, %catch.dispatch.2 ] ; CHECK: %2 = cleanuppad within %h1 [] -; CHECK: %switch = icmp ult i8 %1, 1 -; CHECK: br i1 %switch, label %cleanup2.from.handler2, label %cleanup2.from.catch.dispatch.2 +; CHECK: %3 = icmp eq i8 %1, 0 +; CHECK: br i1 %3, label %cleanup2.from.handler2, label %cleanup2.from.catch.dispatch.2 ; CHECK: cleanup2.from.handler2: ; CHECK: %valueB.reload = load i32, ptr %valueB.spill.addr, align 4 diff --git a/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll b/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll index 4a457cc177e85..a0e29dd19dd84 100644 --- a/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll +++ b/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll @@ -7,8 +7,7 @@ declare void @foo(i32) define void @test(i1 %a) { ; CHECK-LABEL: define void @test( ; CHECK-SAME: i1 [[A:%.*]]) { -; CHECK-NEXT: [[A_OFF:%.*]] = add i1 [[A]], true -; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i1 [[A_OFF]], true +; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i1 [[A]], true ; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]] ; CHECK: common.ret: ; CHECK-NEXT: ret void @@ -209,8 +208,7 @@ define void @test5(i8 %a) { ; CHECK-SAME: i8 [[A:%.*]]) { ; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[A]], 2 ; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) -; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], -1 -; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1 +; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], 1 ; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]] ; CHECK: common.ret: ; CHECK-NEXT: ret void @@ -243,8 +241,7 @@ define void @test6(i8 %a) { ; CHECK-NEXT: [[AND:%.*]] = and i8 [[A]], -2 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], -2 ; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) -; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], 1 -; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1 +; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], -1 ; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]] ; CHECK: common.ret: ; CHECK-NEXT: ret void @@ -279,8 +276,7 @@ define void @test7(i8 %a) { ; CHECK-NEXT: [[AND:%.*]] = and i8 [[A]], -2 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], -2 ; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) -; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], 1 -; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1 +; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], -1 ; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]] ; CHECK: common.ret: ; CHECK-NEXT: ret void diff --git a/llvm/test/Transforms/SimplifyCFG/switch-range-to-icmp.ll b/llvm/test/Transforms/SimplifyCFG/switch-range-to-icmp.ll index 8f2ae2d054f1e..0fc3c19edd1f3 100644 --- a/llvm/test/Transforms/SimplifyCFG/switch-range-to-icmp.ll +++ b/llvm/test/Transforms/SimplifyCFG/switch-range-to-icmp.ll @@ -188,4 +188,217 @@ exit: ret void } +define i32 @wrapping_known_range(i8 range(i8 0, 6) %arg) { +; CHECK-LABEL: @wrapping_known_range( +; CHECK-NEXT: [[ARG_OFF:%.*]] = add i8 [[ARG:%.*]], -1 +; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[ARG_OFF]], 3 +; CHECK-NEXT: br i1 [[SWITCH]], label [[ELSE:%.*]], label [[IF:%.*]] +; CHECK: common.ret: +; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[I0:%.*]], [[IF]] ], [ [[I1:%.*]], [[ELSE]] ] +; CHECK-NEXT: ret i32 [[COMMON_RET_OP]] +; CHECK: if: +; CHECK-NEXT: [[I0]] = call i32 @f(i32 0) +; CHECK-NEXT: br label [[COMMON_RET:%.*]] +; CHECK: else: +; CHECK-NEXT: [[I1]] = call i32 @f(i32 1) +; CHECK-NEXT: br label [[COMMON_RET]] +; + switch i8 %arg, label %else [ + i8 0, label %if + i8 4, label %if + i8 5, label %if + ] + +if: + %i0 = call i32 @f(i32 0) + ret i32 %i0 + +else: + %i1 = call i32 @f(i32 1) + ret i32 %i1 +} + +define i32 @wrapping_known_range_2(i8 range(i8 0, 6) %arg) { +; CHECK-LABEL: @wrapping_known_range_2( +; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[ARG:%.*]], 1 +; CHECK-NEXT: br i1 [[SWITCH]], label [[ELSE:%.*]], label [[IF:%.*]] +; CHECK: common.ret: +; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[I0:%.*]], [[IF]] ], [ [[I1:%.*]], [[ELSE]] ] +; CHECK-NEXT: ret i32 [[COMMON_RET_OP]] +; CHECK: if: +; CHECK-NEXT: [[I0]] = call i32 @f(i32 0) +; CHECK-NEXT: br label [[COMMON_RET:%.*]] +; CHECK: else: +; CHECK-NEXT: [[I1]] = call i32 @f(i32 1) +; CHECK-NEXT: br label [[COMMON_RET]] +; + switch i8 %arg, label %else [ + i8 0, label %if + i8 2, label %if + i8 3, label %if + i8 4, label %if + i8 5, label %if + ] + +if: + %i0 = call i32 @f(i32 0) + ret i32 %i0 + +else: + %i1 = call i32 @f(i32 1) + ret i32 %i1 +} + +define i32 @wrapping_range(i8 %arg) { +; CHECK-LABEL: @wrapping_range( +; CHECK-NEXT: [[ARG_OFF:%.*]] = add i8 [[ARG:%.*]], -1 +; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[ARG_OFF]], -4 +; CHECK-NEXT: br i1 [[SWITCH]], label [[ELSE:%.*]], label [[IF:%.*]] +; CHECK: common.ret: +; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[I0:%.*]], [[IF]] ], [ [[I1:%.*]], [[ELSE]] ] +; CHECK-NEXT: ret i32 [[COMMON_RET_OP]] +; CHECK: if: +; CHECK-NEXT: [[I0]] = call i32 @f(i32 0) +; CHECK-NEXT: br label [[COMMON_RET:%.*]] +; CHECK: else: +; CHECK-NEXT: [[I1]] = call i32 @f(i32 1) +; CHECK-NEXT: br label [[COMMON_RET]] +; + switch i8 %arg, label %else [ + i8 0, label %if + i8 -3, label %if + i8 -2, label %if + i8 -1, label %if + ] + +if: + %i0 = call i32 @f(i32 0) + ret i32 %i0 + +else: + %i1 = call i32 @f(i32 1) + ret i32 %i1 +} + +define i8 @wrapping_range_phi(i8 %arg) { +; CHECK-LABEL: @wrapping_range_phi( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[ARG_OFF:%.*]] = add i8 [[ARG:%.*]], -1 +; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[ARG_OFF]], -2 +; CHECK-NEXT: [[SPEC_SELECT:%.*]] = select i1 [[SWITCH]], i8 0, i8 1 +; CHECK-NEXT: ret i8 [[SPEC_SELECT]] +; +entry: + switch i8 %arg, label %else [ + i8 0, label %if + i8 -1, label %if + ] + +if: + %i = phi i8 [ 0, %else ], [ 1, %entry ], [ 1, %entry ] + ret i8 %i + +else: + br label %if +} + +define i32 @no_continuous_wrapping_range(i8 %arg) { +; CHECK-LABEL: @no_continuous_wrapping_range( +; CHECK-NEXT: switch i8 [[ARG:%.*]], label [[ELSE:%.*]] [ +; CHECK-NEXT: i8 0, label [[IF:%.*]] +; CHECK-NEXT: i8 -3, label [[IF]] +; CHECK-NEXT: i8 -1, label [[IF]] +; CHECK-NEXT: ] +; CHECK: common.ret: +; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[I0:%.*]], [[IF]] ], [ [[I1:%.*]], [[ELSE]] ] +; CHECK-NEXT: ret i32 [[COMMON_RET_OP]] +; CHECK: if: +; CHECK-NEXT: [[I0]] = call i32 @f(i32 0) +; CHECK-NEXT: br label [[COMMON_RET:%.*]] +; CHECK: else: +; CHECK-NEXT: [[I1]] = call i32 @f(i32 1) +; CHECK-NEXT: br label [[COMMON_RET]] +; + switch i8 %arg, label %else [ + i8 0, label %if + i8 -3, label %if + i8 -1, label %if + ] + +if: + %i0 = call i32 @f(i32 0) + ret i32 %i0 + +else: + %i1 = call i32 @f(i32 1) + ret i32 %i1 +} + +define i32 @one_case_1(i32 %x) { +; CHECK-LABEL: @one_case_1( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i32 [[X:%.*]], 10 +; CHECK-NEXT: br i1 [[SWITCH]], label [[A:%.*]], label [[B:%.*]] +; CHECK: common.ret: +; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[TMP0:%.*]], [[B]] ], [ [[TMP1:%.*]], [[A]] ] +; CHECK-NEXT: ret i32 [[COMMON_RET_OP]] +; CHECK: a: +; CHECK-NEXT: [[TMP0]] = call i32 @f(i32 0) +; CHECK-NEXT: br label [[COMMON_RET:%.*]] +; CHECK: b: +; CHECK-NEXT: [[TMP1]] = call i32 @f(i32 1) +; CHECK-NEXT: br label [[COMMON_RET]] +; +entry: + switch i32 %x, label %unreachable [ + i32 5, label %a + i32 6, label %a + i32 7, label %a + i32 10, label %b + ] + +unreachable: + unreachable +a: + %0 = call i32 @f(i32 0) + ret i32 %0 +b: + %1 = call i32 @f(i32 1) + ret i32 %1 +} + +define i32 @one_case_2(i32 %x) { +; CHECK-LABEL: @one_case_2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i32 [[X:%.*]], 5 +; CHECK-NEXT: br i1 [[SWITCH]], label [[A:%.*]], label [[B:%.*]] +; CHECK: common.ret: +; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[TMP0:%.*]], [[A]] ], [ [[TMP1:%.*]], [[B]] ] +; CHECK-NEXT: ret i32 [[COMMON_RET_OP]] +; CHECK: a: +; CHECK-NEXT: [[TMP0]] = call i32 @f(i32 0) +; CHECK-NEXT: br label [[COMMON_RET:%.*]] +; CHECK: b: +; CHECK-NEXT: [[TMP1]] = call i32 @f(i32 1) +; CHECK-NEXT: br label [[COMMON_RET]] +; +entry: + switch i32 %x, label %unreachable [ + i32 5, label %a + i32 10, label %b + i32 11, label %b + i32 12, label %b + i32 13, label %b + ] + +unreachable: + unreachable +a: + %0 = call i32 @f(i32 0) + ret i32 %0 +b: + %1 = call i32 @f(i32 1) + ret i32 %1 +} + declare void @bar(ptr nonnull dereferenceable(4))