Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
inline class_match<const SCEVConstant> m_SCEVConstant() {
return class_match<const SCEVConstant>();
}
inline class_match<const SCEVVScale> m_SCEVVScale() {
return class_match<const SCEVVScale>();
}

template <typename Class> struct bind_ty {
Class *&VR;
Expand Down
108 changes: 45 additions & 63 deletions llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -923,10 +923,11 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
/// If S involves the addition of a constant integer value, return that integer
/// value, and mutate S to point to a new SCEV with that value excluded.
static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
if (C->getAPInt().getSignificantBits() <= 64) {
S = SE.getConstant(C->getType(), 0);
return Immediate::getFixed(C->getValue()->getSExtValue());
const APInt *C;
if (match(S, m_scev_APInt(C))) {
if (C->getSignificantBits() <= 64) {
S = SE.getConstant(S->getType(), 0);
return Immediate::getFixed(C->getSExtValue());
}
} else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
SmallVector<const SCEV *, 8> NewOps(Add->operands());
Expand All @@ -942,14 +943,10 @@ static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
// FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
SCEV::FlagAnyWrap);
return Result;
} else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
if (EnableVScaleImmediates && M->getNumOperands() == 2) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0)))
if (isa<SCEVVScale>(M->getOperand(1))) {
S = SE.getConstant(M->getType(), 0);
return Immediate::getScalable(C->getValue()->getSExtValue());
}
}
} else if (EnableVScaleImmediates &&
match(S, m_scev_Mul(m_scev_APInt(C), m_SCEVVScale()))) {
S = SE.getConstant(S->getType(), 0);
return Immediate::getScalable(C->getSExtValue());
}
return Immediate::getZero();
}
Expand Down Expand Up @@ -1133,23 +1130,22 @@ static bool isHighCostExpansion(const SCEV *S,
return false;
}

if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
if (Mul->getNumOperands() == 2) {
// Multiplication by a constant is ok
if (isa<SCEVConstant>(Mul->getOperand(0)))
return isHighCostExpansion(Mul->getOperand(1), Processed, SE);

// If we have the value of one operand, check if an existing
// multiplication already generates this expression.
if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Mul->getOperand(1))) {
Value *UVal = U->getValue();
for (User *UR : UVal->users()) {
// If U is a constant, it may be used by a ConstantExpr.
Instruction *UI = dyn_cast<Instruction>(UR);
if (UI && UI->getOpcode() == Instruction::Mul &&
SE.isSCEVable(UI->getType())) {
return SE.getSCEV(UI) == Mul;
}
const SCEV *Op0, *Op1;
if (match(S, m_scev_Mul(m_SCEV(Op0), m_SCEV(Op1)))) {
// Multiplication by a constant is ok
if (isa<SCEVConstant>(Op0))
return isHighCostExpansion(Op1, Processed, SE);

// If we have the value of one operand, check if an existing
// multiplication already generates this expression.
if (const auto *U = dyn_cast<SCEVUnknown>(Op1)) {
Value *UVal = U->getValue();
for (User *UR : UVal->users()) {
// If U is a constant, it may be used by a ConstantExpr.
Instruction *UI = dyn_cast<Instruction>(UR);
if (UI && UI->getOpcode() == Instruction::Mul &&
SE.isSCEVable(UI->getType())) {
return SE.getSCEV(UI) == S;
}
}
}
Expand Down Expand Up @@ -3333,14 +3329,11 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
IncOffset = Immediate::getFixed(IncConst->getValue()->getSExtValue());
} else {
// Look for mul(vscale, constant), to detect a scalable offset.
auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
if (!IncVScale || IncVScale->getNumOperands() != 2 ||
!isa<SCEVVScale>(IncVScale->getOperand(1)))
return false;
auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
const APInt *C;
if (!match(IncExpr, m_scev_Mul(m_scev_APInt(C), m_SCEVVScale())) ||
C->getSignificantBits() > 64)
return false;
IncOffset = Immediate::getScalable(Scale->getValue()->getSExtValue());
IncOffset = Immediate::getScalable(C->getSExtValue());
}

if (!isAddressUse(TTI, UserInst, Operand))
Expand Down Expand Up @@ -3818,6 +3811,8 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
return nullptr;
}
const SCEV *Start, *Step;
const SCEVConstant *Op0;
const SCEV *Op1;
if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step)))) {
// Split a non-zero base out of an addrec.
if (Start->isZero())
Expand All @@ -3839,19 +3834,13 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
// FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
SCEV::FlagAnyWrap);
}
} else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
} else if (match(S, m_scev_Mul(m_SCEVConstant(Op0), m_SCEV(Op1)))) {
// Break (C * (a + b + c)) into C*a + C*b + C*c.
if (Mul->getNumOperands() != 2)
return S;
if (const SCEVConstant *Op0 =
dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
const SCEV *Remainder =
CollectSubexprs(Mul->getOperand(1), C, Ops, L, SE, Depth+1);
if (Remainder)
Ops.push_back(SE.getMulExpr(C, Remainder));
return nullptr;
}
C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
const SCEV *Remainder = CollectSubexprs(Op1, C, Ops, L, SE, Depth + 1);
if (Remainder)
Ops.push_back(SE.getMulExpr(C, Remainder));
return nullptr;
}
return S;
}
Expand Down Expand Up @@ -6478,13 +6467,10 @@ struct SCEVDbgValueBuilder {
/// Components of the expression are omitted if they are an identity function.
/// Chain (non-affine) SCEVs are not supported.
bool SCEVToValueExpr(const llvm::SCEVAddRecExpr &SAR, ScalarEvolution &SE) {
assert(SAR.isAffine() && "Expected affine SCEV");
// TODO: Is this check needed?
if (isa<SCEVAddRecExpr>(SAR.getStart()))
return false;

const SCEV *Start = SAR.getStart();
const SCEV *Stride = SAR.getStepRecurrence(SE);
const SCEV *Start, *Stride;
[[maybe_unused]] bool Match =
match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride)));
assert(Match && "Expected affine SCEV");

// Skip pushing arithmetic noops.
if (!isIdentityFunction(llvm::dwarf::DW_OP_mul, Stride)) {
Expand Down Expand Up @@ -6549,14 +6535,10 @@ struct SCEVDbgValueBuilder {
/// Components of the expression are omitted if they are an identity function.
bool SCEVToIterCountExpr(const llvm::SCEVAddRecExpr &SAR,
ScalarEvolution &SE) {
assert(SAR.isAffine() && "Expected affine SCEV");
if (isa<SCEVAddRecExpr>(SAR.getStart())) {
LLVM_DEBUG(dbgs() << "scev-salvage: IV SCEV. Unsupported nested AddRec: "
<< SAR << '\n');
return false;
}
const SCEV *Start = SAR.getStart();
const SCEV *Stride = SAR.getStepRecurrence(SE);
const SCEV *Start, *Stride;
[[maybe_unused]] bool Match =
match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride)));
assert(Match && "Expected affine SCEV");

// Skip pushing arithmetic noops.
if (!isIdentityFunction(llvm::dwarf::DW_OP_minus, Start)) {
Expand Down