diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 8876e4ed6ae4f..76800d43828c2 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -763,6 +763,10 @@ class ScalarEvolution { getUMinFromMismatchedTypes(SmallVectorImpl &Ops, bool Sequential = false); + /// Try to match the pattern generated by getURemExpr(A, B). If successful, + /// Assign A and B to LHS and RHS, respectively. + LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS); + /// Transitively follow the chain of pointer-type operands until reaching a /// SCEV that does not have a single pointer operand. This returns a /// SCEVUnknown pointer for well-formed pointer-type expressions, but corner @@ -2316,10 +2320,6 @@ class ScalarEvolution { /// an add rec on said loop. void getUsedLoops(const SCEV *S, SmallPtrSetImpl &LoopsUsed); - /// Try to match the pattern generated by getURemExpr(A, B). If successful, - /// Assign A and B to LHS and RHS, respectively. - LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS); - /// Look for a SCEV expression with type `SCEVType` and operands `Ops` in /// `UniqueSCEVs`. Return if found, else nullptr. SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef Ops); diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 30bcff7c14923..656ab8d0bdb1a 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -15557,6 +15557,123 @@ void ScalarEvolution::LoopGuards::collectFromPHI( } } +// Checks whether Expr is a non-negative constant, and Divisor is a positive +// constant, and returns their APInt in ExprVal and in DivisorVal. +static bool getNonNegExprAndPosDivisor(const SCEV *Expr, const SCEV *Divisor, + APInt &ExprVal, APInt &DivisorVal) { + auto *ConstExpr = dyn_cast(Expr); + auto *ConstDivisor = dyn_cast(Divisor); + if (!ConstExpr || !ConstDivisor) + return false; + ExprVal = ConstExpr->getAPInt(); + DivisorVal = ConstDivisor->getAPInt(); + return ExprVal.isNonNegative() && !DivisorVal.isNonPositive(); +} + +// Return a new SCEV that modifies \p Expr to the closest number divisible by +// \p Divisor and less than or equal to Expr. +// For now, only handle constant Expr and Divisor. +static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, + const SCEV *Divisor, + ScalarEvolution &SE) { + APInt ExprVal; + APInt DivisorVal; + if (!getNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) + return Expr; + APInt Rem = ExprVal.urem(DivisorVal); + // return the SCEV: Expr - Expr % Divisor + return SE.getConstant(ExprVal - Rem); +} + +// Return a new SCEV that modifies \p Expr to the closest number divisible by +// \p Divisor and greater than or equal to Expr. +// For now, only handle constant Expr and Divisor. +static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr, + const SCEV *Divisor, + ScalarEvolution &SE) { + APInt ExprVal; + APInt DivisorVal; + if (!getNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) + return Expr; + APInt Rem = ExprVal.urem(DivisorVal); + if (!Rem.isZero()) + // return the SCEV: Expr + Divisor - Expr % Divisor + return SE.getConstant(ExprVal + DivisorVal - Rem); + return Expr; +} + +static bool collectDivisibilityInformation( + ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, + DenseMap &DivInfo, + DenseMap &Multiples, ScalarEvolution &SE) { + // If we have LHS == 0, check if LHS is computing a property of some unknown + // SCEV %v which we can rewrite %v to express explicitly. + if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero())) + return false; + // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to + // explicitly express that. + const SCEV *URemLHS = nullptr; + const SCEV *URemRHS = nullptr; + if (!SE.matchURem(LHS, URemLHS, URemRHS)) + return false; + if (const SCEVUnknown *LHSUnknown = dyn_cast(URemLHS)) { + const auto *Multiple = SE.getMulExpr(SE.getUDivExpr(LHS, URemRHS), URemRHS); + DivInfo[LHSUnknown] = Multiple; + Multiples[LHSUnknown] = URemRHS; + return true; + } + return false; +} + +// Check if the condition is a divisibility guard (A % B == 0). +static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, + ScalarEvolution &SE) { + const SCEV *X, *Y; + return SE.matchURem(LHS, X, Y) && RHS->isZero(); +} + +// Apply divisibility by \p Divisor on MinMaxExpr with constant values, +// recursively. This is done by aligning up/down the constant value to the +// Divisor. +static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, + const SCEV *Divisor, + ScalarEvolution &SE) { + // Return true if \p Expr is a MinMax SCEV expression with a non-negative + // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS + // the non-constant operand and in \p LHS the constant operand. + auto IsMinMaxSCEVWithNonNegativeConstant = + [](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS, + const SCEV *&RHS) { + if (auto *MinMax = dyn_cast(Expr)) { + if (MinMax->getNumOperands() != 2) + return false; + if (auto *C = dyn_cast(MinMax->getOperand(0))) { + if (C->getAPInt().isNegative()) + return false; + SCTy = MinMax->getSCEVType(); + LHS = MinMax->getOperand(0); + RHS = MinMax->getOperand(1); + return true; + } + } + return false; + }; + + const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr; + SCEVTypes SCTy; + if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS, + MinMaxRHS)) + return MinMaxExpr; + auto IsMin = isa(MinMaxExpr) || isa(MinMaxExpr); + assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!"); + auto *DivisibleExpr = + IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE) + : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE); + SmallVector Ops = { + applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr}; + return SE.getMinMaxExpr(SCTy, Ops); +} + void ScalarEvolution::LoopGuards::collectFromBlock( ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards, const BasicBlock *Block, const BasicBlock *Pred, @@ -15567,19 +15684,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock( SmallVector ExprsToRewrite; auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, - DenseMap - &RewriteMap) { + DenseMap &RewriteMap, + const DenseMap + &DivInfo) { // WARNING: It is generally unsound to apply any wrap flags to the proposed // replacement SCEV which isn't directly implied by the structure of that // SCEV. In particular, using contextual facts to imply flags is *NOT* // legal. See the scoping rules for flags in the header to understand why. - // If LHS is a constant, apply information to the other expression. - if (isa(LHS)) { - std::swap(LHS, RHS); - Predicate = CmpInst::getSwappedPredicate(Predicate); - } - // Check for a condition of the form (-C1 + X < C2). InstCombine will // create this form when combining two checks of the form (X u< C2 + C1) and // (X >=u C1). @@ -15612,115 +15724,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock( if (MatchRangeCheckIdiom()) return; - // Return true if \p Expr is a MinMax SCEV expression with a non-negative - // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS - // the non-constant operand and in \p LHS the constant operand. - auto IsMinMaxSCEVWithNonNegativeConstant = - [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS, - const SCEV *&RHS) { - if (auto *MinMax = dyn_cast(Expr)) { - if (MinMax->getNumOperands() != 2) - return false; - if (auto *C = dyn_cast(MinMax->getOperand(0))) { - if (C->getAPInt().isNegative()) - return false; - SCTy = MinMax->getSCEVType(); - LHS = MinMax->getOperand(0); - RHS = MinMax->getOperand(1); - return true; - } - } - return false; - }; - - // Checks whether Expr is a non-negative constant, and Divisor is a positive - // constant, and returns their APInt in ExprVal and in DivisorVal. - auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor, - APInt &ExprVal, APInt &DivisorVal) { - auto *ConstExpr = dyn_cast(Expr); - auto *ConstDivisor = dyn_cast(Divisor); - if (!ConstExpr || !ConstDivisor) - return false; - ExprVal = ConstExpr->getAPInt(); - DivisorVal = ConstDivisor->getAPInt(); - return ExprVal.isNonNegative() && !DivisorVal.isNonPositive(); - }; - - // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and greater or equal than Expr. - // For now, only handle constant Expr and Divisor. - auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr, - const SCEV *Divisor) { - APInt ExprVal; - APInt DivisorVal; - if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) - return Expr; - APInt Rem = ExprVal.urem(DivisorVal); - if (!Rem.isZero()) - // return the SCEV: Expr + Divisor - Expr % Divisor - return SE.getConstant(ExprVal + DivisorVal - Rem); - return Expr; - }; - - // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and less or equal than Expr. - // For now, only handle constant Expr and Divisor. - auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr, - const SCEV *Divisor) { - APInt ExprVal; - APInt DivisorVal; - if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) - return Expr; - APInt Rem = ExprVal.urem(DivisorVal); - // return the SCEV: Expr - Expr % Divisor - return SE.getConstant(ExprVal - Rem); - }; - - // Apply divisibilty by \p Divisor on MinMaxExpr with constant values, - // recursively. This is done by aligning up/down the constant value to the - // Divisor. - std::function - ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr, - const SCEV *Divisor) { - const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr; - SCEVTypes SCTy; - if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS, - MinMaxRHS)) - return MinMaxExpr; - auto IsMin = - isa(MinMaxExpr) || isa(MinMaxExpr); - assert(SE.isKnownNonNegative(MinMaxLHS) && - "Expected non-negative operand!"); - auto *DivisibleExpr = - IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor) - : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor); - SmallVector Ops = { - ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr}; - return SE.getMinMaxExpr(SCTy, Ops); - }; - - // If we have LHS == 0, check if LHS is computing a property of some unknown - // SCEV %v which we can rewrite %v to express explicitly. - if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) { - // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to - // explicitly express that. - const SCEV *URemLHS = nullptr; - const SCEV *URemRHS = nullptr; - if (SE.matchURem(LHS, URemLHS, URemRHS)) { - if (const SCEVUnknown *LHSUnknown = dyn_cast(URemLHS)) { - auto I = RewriteMap.find(LHSUnknown); - const SCEV *RewrittenLHS = - I != RewriteMap.end() ? I->second : LHSUnknown; - RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS); - const auto *Multiple = - SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS); - RewriteMap[LHSUnknown] = Multiple; - ExprsToRewrite.push_back(LHSUnknown); - return; - } - } - } - // Do not apply information for constants or if RHS contains an AddRec. if (isa(LHS) || SE.containsAddRecurrence(RHS)) return; @@ -15751,7 +15754,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock( const SCEV *RewrittenLHS = GetMaybeRewritten(LHS); const SCEV *DividesBy = nullptr; - const APInt &Multiple = SE.getConstantMultiple(RewrittenLHS); + // Apply divisibility information when computing the constant multiple. + LoopGuards DivGuards(SE); + DivGuards.RewriteMap = DivInfo; + const APInt &Multiple = + SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS)); if (!Multiple.isOne()) DividesBy = SE.getConstant(Multiple); @@ -15775,21 +15782,23 @@ void ScalarEvolution::LoopGuards::collectFromBlock( [[fallthrough]]; case CmpInst::ICMP_SLT: { RHS = SE.getMinusSCEV(RHS, One); - RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RHS = DividesBy ? getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE) + : RHS; break; } case CmpInst::ICMP_UGT: case CmpInst::ICMP_SGT: RHS = SE.getAddExpr(RHS, One); - RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RHS = DividesBy ? getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE) : RHS; break; case CmpInst::ICMP_ULE: case CmpInst::ICMP_SLE: - RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RHS = DividesBy ? getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE) + : RHS; break; case CmpInst::ICMP_UGE: case CmpInst::ICMP_SGE: - RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + RHS = DividesBy ? getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE) : RHS; break; default: break; @@ -15843,7 +15852,8 @@ void ScalarEvolution::LoopGuards::collectFromBlock( case CmpInst::ICMP_NE: if (match(RHS, m_scev_Zero())) { const SCEV *OneAlignedUp = - DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One; + DividesBy ? getNextSCEVDivisibleByDivisor(One, DividesBy, SE) + : One; To = SE.getUMaxExpr(FromRewritten, OneAlignedUp); } break; @@ -15916,8 +15926,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock( // Now apply the information from the collected conditions to // Guards.RewriteMap. Conditions are processed in reverse order, so the - // earliest conditions is processed first. This ensures the SCEVs with the + // earliest conditions is processed first, except guards with divisibility + // information, which are moved to the back. This ensures the SCEVs with the // shortest dependency chains are constructed first. + SmallVector> + GuardsToProcess; for (auto [Term, EnterIfTrue] : reverse(Terms)) { SmallVector Worklist; SmallPtrSet Visited; @@ -15932,7 +15945,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock( EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate(); const auto *LHS = SE.getSCEV(Cmp->getOperand(0)); const auto *RHS = SE.getSCEV(Cmp->getOperand(1)); - CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap); + // If LHS is a constant, apply information to the other expression. + if (isa(LHS)) { + std::swap(LHS, RHS); + Predicate = CmpInst::getSwappedPredicate(Predicate); + } + GuardsToProcess.emplace_back(Predicate, LHS, RHS); continue; } @@ -15945,6 +15963,28 @@ void ScalarEvolution::LoopGuards::collectFromBlock( } } + // Process divisibility guards in reverse order to populate DivInfo early. + DenseMap Multiples; + DenseMap DivInfo; + for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) { + if (!isDivisibilityGuard(LHS, RHS, SE)) + continue; + collectDivisibilityInformation(Predicate, LHS, RHS, DivInfo, Multiples, SE); + } + + for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) + CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivInfo); + + // Apply divisibility information last. This ensures it is applied to the + // outermost expression after other rewrites for the given value. + for (const auto &[K, V] : Multiples) { + Guards.RewriteMap[K] = SE.getMulExpr( + SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(Guards.rewrite(K), V, SE), + V), + V); + ExprsToRewrite.push_back(K); + } + // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of // the replacement expressions are contained in the ranges of the replaced // expressions. diff --git a/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll b/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll index 14ee00d77197c..2763860e79875 100644 --- a/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll +++ b/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll @@ -114,7 +114,7 @@ define i32 @urem_order1(i32 %n) { ; CHECK: [[LOOP]]: ; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[IV_NEXT:%.*]], %[[LOOP]] ], [ 0, %[[LOOP_PREHEADER]] ] ; CHECK-NEXT: call void @foo() -; CHECK-NEXT: [[IV_NEXT]] = add i32 [[IV]], 3 +; CHECK-NEXT: [[IV_NEXT]] = add nuw i32 [[IV]], 3 ; CHECK-NEXT: [[EC:%.*]] = icmp eq i32 [[IV_NEXT]], [[N]] ; CHECK-NEXT: br i1 [[EC]], label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP]] ; CHECK: [[EXIT_LOOPEXIT]]: @@ -205,13 +205,12 @@ define i64 @test_loop_with_div_order_1(i64 %n) { ; CHECK-NEXT: [[PARITY_CHECK:%.*]] = icmp eq i64 [[IS_ODD]], 0 ; CHECK-NEXT: br i1 [[PARITY_CHECK]], label %[[LOOP_PREHEADER:.*]], label %[[EXIT]] ; CHECK: [[LOOP_PREHEADER]]: -; CHECK-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[UPPER_BOUND]], i64 1) ; CHECK-NEXT: br label %[[LOOP:.*]] ; CHECK: [[LOOP]]: ; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[IV_NEXT:%.*]], %[[LOOP]] ], [ 0, %[[LOOP_PREHEADER]] ] ; CHECK-NEXT: [[DUMMY:%.*]] = load volatile i64, ptr null, align 8 ; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 -; CHECK-NEXT: [[EXITCOND:%.*]] = icmp ne i64 [[IV_NEXT]], [[UMAX]] +; CHECK-NEXT: [[EXITCOND:%.*]] = icmp ne i64 [[IV_NEXT]], [[UPPER_BOUND]] ; CHECK-NEXT: br i1 [[EXITCOND]], label %[[LOOP]], label %[[EXIT_LOOPEXIT:.*]] ; CHECK: [[EXIT_LOOPEXIT]]: ; CHECK-NEXT: br label %[[EXIT]]