Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -2316,10 +2316,6 @@ class ScalarEvolution {
/// an add rec on said loop.
void getUsedLoops(const SCEV *S, SmallPtrSetImpl<const Loop *> &LoopsUsed);

/// Try to match the pattern generated by getURemExpr(A, B). If successful,
/// Assign A and B to LHS and RHS, respectively.
LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS);

/// Look for a SCEV expression with type `SCEVType` and operands `Ops` in
/// `UniqueSCEVs`. Return if found, else nullptr.
SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops);
Expand Down
74 changes: 74 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,80 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
}

/// Match unsigned remainder pattern.
/// Matches patterns generated by getURemExpr.
template <typename Op0_t, typename Op1_t> struct SCEVURem_match {
Op0_t Op0;
Op1_t Op1;
ScalarEvolution &SE;

SCEVURem_match(Op0_t Op0, Op1_t Op1, ScalarEvolution &SE)
: Op0(Op0), Op1(Op1), SE(SE) {}

bool match(const SCEV *Expr) const {
if (Expr->getType()->isPointerTy())
return false;

// Try to match 'zext (trunc A to iB) to iY', which is used
// for URem with constant power-of-2 second operands. Make sure the size of
// the operand A matches the size of the whole expressions.
const SCEV *LHS;
if (SCEVPatternMatch::match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
// Bail out if the type of the LHS is larger than the type of the
// expression for now.
if (SE.getTypeSizeInBits(LHS->getType()) >
SE.getTypeSizeInBits(Expr->getType()))
return false;
if (LHS->getType() != Expr->getType())
LHS = SE.getZeroExtendExpr(LHS, Expr->getType());
const SCEV *RHS =
SE.getConstant(APInt(SE.getTypeSizeInBits(Expr->getType()), 1)
<< SE.getTypeSizeInBits(TruncTy));
return Op0.match(LHS) && Op1.match(RHS);
}
const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
if (Add == nullptr || Add->getNumOperands() != 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to use m_scev_Add here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, this is something I have a follow up in progress to combine the Mul/Add checks using pattern matching

return false;

const SCEV *A = Add->getOperand(1);
const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));

if (Mul == nullptr)
return false;

const auto MatchURemWithDivisor = [&](const SCEV *B) {
// (SomeExpr + (-(SomeExpr / B) * B)).
if (Expr == SE.getURemExpr(A, B))
return Op0.match(A) && Op1.match(B);
return false;
};

// (SomeExpr + (-1 * (SomeExpr / B) * B)).
if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
return MatchURemWithDivisor(Mul->getOperand(1)) ||
MatchURemWithDivisor(Mul->getOperand(2));

// (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
if (Mul->getNumOperands() == 2)
return MatchURemWithDivisor(Mul->getOperand(1)) ||
MatchURemWithDivisor(Mul->getOperand(0)) ||
MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(1))) ||
MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(0)));
return false;
}
};

/// Match the mathematical pattern A - (A / B) * B, where A and B can be
/// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
/// for URem with constant power-of-2 second operands. It's not always easy, as
/// A and B can be folded (imagine A is X / 2, and B is 4, A / B becomes X / 8).
template <typename Op0_t, typename Op1_t>
inline SCEVURem_match<Op0_t, Op1_t> m_scev_URem(Op0_t LHS, Op1_t RHS,
ScalarEvolution &SE) {
return SCEVURem_match<Op0_t, Op1_t>(LHS, RHS, SE);
}

inline class_match<const Loop> m_Loop() { return class_match<const Loop>(); }

/// Match an affine SCEVAddRecExpr.
Expand Down
102 changes: 18 additions & 84 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1774,7 +1774,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
{
const SCEV *LHS;
const SCEV *RHS;
if (matchURem(Op, LHS, RHS))
if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
getZeroExtendExpr(RHS, Ty, Depth + 1));
}
Expand Down Expand Up @@ -2699,17 +2699,12 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
}

// Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
if (Ops.size() == 2) {
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
if (Mul && Mul->getNumOperands() == 2 &&
Mul->getOperand(0)->isAllOnesValue()) {
const SCEV *X;
const SCEV *Y;
if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
return getMulExpr(Y, getUDivExpr(X, Y));
}
}
}
const SCEV *Y;
if (Ops.size() == 2 &&
match(Ops[0],
m_scev_Mul(m_scev_AllOnes(),
m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
return getMulExpr(Y, getUDivExpr(Ops[1], Y));

// Skip past any other cast SCEVs.
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
Expand Down Expand Up @@ -15410,65 +15405,6 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
}
}

// Match the mathematical pattern A - (A / B) * B, where A and B can be
// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
// for URem with constant power-of-2 second operands.
// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
// 4, A / B becomes X / 8).
bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
const SCEV *&RHS) {
if (Expr->getType()->isPointerTy())
return false;

// Try to match 'zext (trunc A to iB) to iY', which is used
// for URem with constant power-of-2 second operands. Make sure the size of
// the operand A matches the size of the whole expressions.
if (match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
// Bail out if the type of the LHS is larger than the type of the
// expression for now.
if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(Expr->getType()))
return false;
if (LHS->getType() != Expr->getType())
LHS = getZeroExtendExpr(LHS, Expr->getType());
RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
<< getTypeSizeInBits(TruncTy));
return true;
}
const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
if (Add == nullptr || Add->getNumOperands() != 2)
return false;

const SCEV *A = Add->getOperand(1);
const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));

if (Mul == nullptr)
return false;

const auto MatchURemWithDivisor = [&](const SCEV *B) {
// (SomeExpr + (-(SomeExpr / B) * B)).
if (Expr == getURemExpr(A, B)) {
LHS = A;
RHS = B;
return true;
}
return false;
};

// (SomeExpr + (-1 * (SomeExpr / B) * B)).
if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
return MatchURemWithDivisor(Mul->getOperand(1)) ||
MatchURemWithDivisor(Mul->getOperand(2));

// (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
if (Mul->getNumOperands() == 2)
return MatchURemWithDivisor(Mul->getOperand(1)) ||
MatchURemWithDivisor(Mul->getOperand(0)) ||
MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
return false;
}

ScalarEvolution::LoopGuards
ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
BasicBlock *Header = L->getHeader();
Expand Down Expand Up @@ -15689,20 +15625,18 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
// explicitly express that.
const SCEV *URemLHS = nullptr;
const SCEVUnknown *URemLHS = nullptr;
const SCEV *URemRHS = nullptr;
if (SE.matchURem(LHS, URemLHS, URemRHS)) {
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
auto I = RewriteMap.find(LHSUnknown);
const SCEV *RewrittenLHS =
I != RewriteMap.end() ? I->second : LHSUnknown;
RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
const auto *Multiple =
SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
RewriteMap[LHSUnknown] = Multiple;
ExprsToRewrite.push_back(LHSUnknown);
return;
}
if (match(LHS,
m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
auto I = RewriteMap.find(URemLHS);
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
const auto *Multiple =
SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
RewriteMap[URemLHS] = Multiple;
ExprsToRewrite.push_back(URemLHS);
return;
}
}

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
// Recognize the canonical representation of an unsimplifed urem.
const SCEV *URemLHS = nullptr;
const SCEV *URemRHS = nullptr;
if (SE.matchURem(S, URemLHS, URemRHS)) {
if (match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), SE))) {
Value *LHS = expand(URemLHS);
Value *RHS = expand(URemRHS);
return InsertBinop(Instruction::URem, LHS, RHS, SCEV::FlagAnyWrap,
Expand Down
12 changes: 5 additions & 7 deletions llvm/unittests/Analysis/ScalarEvolutionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ScalarEvolutionNormalization.h"
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
Expand All @@ -26,6 +27,8 @@

namespace llvm {

using namespace SCEVPatternMatch;

// We use this fixture to ensure that we clean up ScalarEvolution before
// deleting the PassManager.
class ScalarEvolutionsTest : public testing::Test {
Expand Down Expand Up @@ -64,11 +67,6 @@ static std::optional<APInt> computeConstantDifference(ScalarEvolution &SE,
return SE.computeConstantDifference(LHS, RHS);
}

static bool matchURem(ScalarEvolution &SE, const SCEV *Expr, const SCEV *&LHS,
const SCEV *&RHS) {
return SE.matchURem(Expr, LHS, RHS);
}

static bool isImpliedCond(
ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS,
const SCEV *RHS, ICmpInst::Predicate FoundPred, const SCEV *FoundLHS,
Expand Down Expand Up @@ -1524,7 +1522,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
auto *URemI = getInstructionByName(F, N);
auto *S = SE.getSCEV(URemI);
const SCEV *LHS, *RHS;
EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
EXPECT_EQ(LHS, SE.getSCEV(URemI->getOperand(0)));
EXPECT_EQ(RHS, SE.getSCEV(URemI->getOperand(1)));
EXPECT_EQ(LHS->getType(), S->getType());
Expand All @@ -1537,7 +1535,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
auto *URem1 = getInstructionByName(F, "rem4");
auto *S = SE.getSCEV(Ext);
const SCEV *LHS, *RHS;
EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
EXPECT_NE(LHS, SE.getSCEV(URem1->getOperand(0)));
// RHS and URem1->getOperand(1) have different widths, so compare the
// integer values.
Expand Down
Loading