diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 75b2bc1c067ec..7a45ae93b185b 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -183,7 +183,8 @@ m_scev_PtrToInt(const Op0_t &Op0) {
 }
 
 /// Match a binary SCEV.
-template <typename SCEVTy, typename Op0_t, typename Op1_t>
+template <typename SCEVTy, typename Op0_t, typename Op1_t,
+          bool Commutable = false>
 struct SCEVBinaryExpr_match {
   Op0_t Op0;
   Op1_t Op1;
@@ -192,15 +193,18 @@ struct SCEVBinaryExpr_match {
 
   bool match(const SCEV *S) const {
     auto *E = dyn_cast<SCEVTy>(S);
-    return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
-           Op1.match(E->getOperand(1));
+    return E && E->getNumOperands() == 2 &&
+           ((Op0.match(E->getOperand(0)) && Op1.match(E->getOperand(1))) ||
+            (Commutable && Op0.match(E->getOperand(1)) &&
+             Op1.match(E->getOperand(0))));
   }
 };
 
-template <typename SCEVTy, typename Op0_t, typename Op1_t>
-inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>
+template <typename SCEVTy, typename Op0_t, typename Op1_t,
+          bool Commutable = false>
+inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, Commutable>
 m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
-  return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>(Op0, Op1);
+  return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, Commutable>(Op0, Op1);
 }
 
 template <typename Op0_t, typename Op1_t>
@@ -215,6 +219,12 @@ m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
   return m_scev_Binary<SCEVMulExpr>(Op0, Op1);
 }
 
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t, true>
+m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVMulExpr, Op0_t, Op1_t, true>(Op0, Op1);
+}
+
 template <typename Op0_t, typename Op1_t>
 inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
 m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 09b126d35bde0..54fc495b559d8 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10785,6 +10785,25 @@ bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
   if (Depth >= 3)
     return false;
 
+  const SCEV *NewLHS, *NewRHS;
+  if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
+      match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
+    const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
+    const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
+
+    // (X * vscale) pred (Y * vscale) ==> X pred Y
+    // when both multiples are NSW.
+    // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
+    // when both multiples are NUW.
+    if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
+        (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
+         !ICmpInst::isSigned(Pred))) {
+      LHS = NewLHS;
+      RHS = NewRHS;
+      Changed = true;
+    }
+  }
+
   // Canonicalize a constant to the right side.
   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
     // Check for both operands constant.
@@ -10959,7 +10978,7 @@ bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
   // Recursively simplify until we either hit a recursion limit or nothing
   // changes.
   if (Changed)
-    return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
+    (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
 
   return Changed;
 }
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-vscale-based-trip-counts.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-vscale-based-trip-counts.ll
index 4444be36c3567..50d5f7ca22de3 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-vscale-based-trip-counts.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-vscale-based-trip-counts.ll
@@ -9,14 +9,14 @@ define void @vscale_mul_4(ptr noalias noundef readonly captures(none) %a, ptr no
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw nsw i64 [[TMP0]], 2
-; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP5:%.*]] = mul nuw i64 [[TMP4]], 4
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP1]], [[TMP5]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw i64 [[TMP10]], 4
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP1]], [[TMP3]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[A]], align 4
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x float>, ptr [[B]], align 4
-; CHECK-NEXT:    [[TMP10:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
-; CHECK-NEXT:    store <vscale x 4 x float> [[TMP10]], ptr [[B]], align 4
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
+; CHECK-NEXT:    store <vscale x 4 x float> [[TMP4]], ptr [[B]], align 4
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
@@ -124,36 +124,29 @@ define void @vscale_mul_12(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 12
-; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 2
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK:       [[VECTOR_PH]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 4
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[TMP7]], align 4
-; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x float>, ptr [[TMP9]], align 4
-; CHECK-NEXT:    [[TMP11:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
-; CHECK-NEXT:    store <vscale x 4 x float> [[TMP11]], ptr [[TMP9]], align 4
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 4 x float>, ptr [[TMP7]], align 4
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 4 x float>, ptr [[TMP12]], align 4
+; CHECK-NEXT:    [[TMP25:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD2]], [[WIDE_LOAD4]]
+; CHECK-NEXT:    store <vscale x 4 x float> [[TMP25]], ptr [[TMP12]], align 4
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP4]]
-; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK-NEXT:    [[TMP22:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK:       [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK:       [[SCALAR_PH]]:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[FOR_BODY]]:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    [[TMP13:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]
@@ -191,17 +184,13 @@ define void @vscale_mul_31(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 31
-; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 3
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK:       [[VECTOR_PH]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 8
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP10:%.*]] = shl nuw i64 [[TMP9]], 2
@@ -226,14 +215,11 @@ define void @vscale_mul_31(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:    br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:       [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK:       [[SCALAR_PH]]:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[FOR_BODY]]:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    [[TMP23:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]
@@ -271,17 +257,13 @@ define void @vscale_mul_64(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 64
-; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 3
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK:       [[VECTOR_PH]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 8
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP10:%.*]] = shl nuw i64 [[TMP9]], 2
@@ -306,14 +288,11 @@ define void @vscale_mul_64(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:    br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK:       [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK:       [[SCALAR_PH]]:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[FOR_BODY]]:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    [[TMP23:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 678960418d7d7..1a68823b4f254 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1768,4 +1768,141 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering3) {
   SE.getSCEV(Or1);
 }
 
+TEST_F(ScalarEvolutionsTest, SimplifyICmpOperands) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M =
+      parseAssemblyString("define i32 @foo(ptr %loc, i32 %a, i32 %b) {"
+                          "entry: "
+                          "  ret i32 %a "
+                          "} ",
+                          Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  // Remove common factor when there's no signed wrapping.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNSW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNSW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_SLT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_ULT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_EQ);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    // Verify the common factor's position doesn't impede simplification.
+    {
+      const SCEV *C = SE.getConstant(A->getType(), 100);
+      const SCEV *CxVS = SE.getMulExpr(C, VS, SCEV::FlagNSW);
+
+      // Verify common factor is available at different indices.
+      ASSERT_TRUE(isa<SCEVVScale>(cast<SCEVMulExpr>(VSxA)->getOperand(0)) !=
+                  isa<SCEVVScale>(cast<SCEVMulExpr>(CxVS)->getOperand(0)));
+
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = CxVS;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_SLT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, C);
+    }
+  });
+
+  // Remove common factor when there's no unsigned wrapping.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNUW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNUW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_ULT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_EQ);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+  });
+
+  // Do not remove common factor due to wrap flag mismatch.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNSW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNUW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+  });
+}
+
 } // end namespace llvm
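Note, not part of the patch: the sketch below restates the new fold from SimplifyICmpOperands as a standalone helper, so the legality condition and the role of the commutative matcher are visible in one place. The helper name stripCommonVScaleFactor is invented for illustration; the matchers and wrap-flag queries are the ones the patch itself uses.

// Illustrative sketch only; mirrors the fold added to SimplifyICmpOperands.
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;
using namespace llvm::SCEVPatternMatch;

// Rewrites (X * vscale) Pred (Y * vscale) into X Pred Y and returns true when
// the multiplies' no-wrap flags justify dropping the common vscale factor.
static bool stripCommonVScaleFactor(ICmpInst::Predicate Pred, const SCEV *&LHS,
                                    const SCEV *&RHS) {
  const SCEV *X, *Y;
  // m_scev_c_Mul is the commutative variant, so it matches regardless of
  // whether vscale ends up as the first or second mul operand after SCEV's
  // operand canonicalization.
  if (!match(LHS, m_scev_c_Mul(m_SCEV(X), m_SCEVVScale())) ||
      !match(RHS, m_scev_c_Mul(m_SCEV(Y), m_SCEVVScale())))
    return false;

  const auto *LMul = cast<SCEVMulExpr>(LHS);
  const auto *RMul = cast<SCEVMulExpr>(RHS);
  // NSW on both multiplies covers every predicate; NUW only covers the
  // unsigned and equality predicates.
  if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
      (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
       !ICmpInst::isSigned(Pred))) {
    LHS = X;
    RHS = Y;
    return true;
  }
  return false;
}

Clients that already route comparisons through ScalarEvolution::SimplifyICmpOperands get this rewrite automatically; the helper only spells out why the loop-vectorizer trip-count checks in the tests above now fold away.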