Skip to content

Commit f256b76

Browse files
committed
[SCEVPatternMatch] Extend m_scev_AffineAddRec with Loop
1 parent e2a8855 commit f256b76

File tree

3 files changed

+25
-23
lines changed

3 files changed

+25
-23
lines changed

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
1515

1616
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
17+
#include <type_traits>
1718

1819
namespace llvm {
1920
namespace SCEVPatternMatch {
@@ -162,13 +163,18 @@ template <typename SCEVTy, typename Op0_t, typename Op1_t>
162163
struct SCEVBinaryExpr_match {
163164
Op0_t Op0;
164165
Op1_t Op1;
166+
const Loop *L;
165167

166-
SCEVBinaryExpr_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
168+
SCEVBinaryExpr_match(Op0_t Op0, Op1_t Op1, const Loop *L = nullptr)
169+
: Op0(Op0), Op1(Op1), L(L) {}
167170

168171
bool match(const SCEV *S) const {
169172
auto *E = dyn_cast<SCEVTy>(S);
173+
bool LoopMatches = true;
174+
if constexpr (std::is_same_v<SCEVTy, SCEVAddRecExpr>)
175+
LoopMatches = !L || (E && E->getLoop() == L);
170176
return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
171-
Op1.match(E->getOperand(1));
177+
Op1.match(E->getOperand(1)) && LoopMatches;
172178
}
173179
};
174180

@@ -198,7 +204,8 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
198204

199205
template <typename Op0_t, typename Op1_t>
200206
inline SCEVBinaryExpr_match<SCEVAddRecExpr, Op0_t, Op1_t>
201-
m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1) {
207+
m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1,
208+
const Loop *L = nullptr) {
202209
return m_scev_Binary<SCEVAddRecExpr>(Op0, Op1);
203210
}
204211
} // namespace SCEVPatternMatch

llvm/lib/Transforms/Scalar/IndVarSimplify.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -808,8 +808,7 @@ static bool isLoopCounter(PHINode* Phi, Loop *L,
808808
return false;
809809

810810
const SCEV *S = SE->getSCEV(Phi);
811-
if (!match(S, m_scev_AffineAddRec(m_SCEV(), m_scev_One())) ||
812-
cast<SCEVAddRecExpr>(S)->getLoop() != L)
811+
if (!match(S, m_scev_AffineAddRec(m_SCEV(), m_scev_One(), L)))
813812
return false;
814813

815814
int LatchIdx = Phi->getBasicBlockIndex(L->getLoopLatch());

llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,8 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
455455
// random store we can't handle.
456456
const SCEV *StoreEv = SE->getSCEV(StorePtr);
457457
const SCEVConstant *Stride;
458-
if (!match(StoreEv, m_scev_AffineAddRec(m_SCEV(), m_SCEVConstant(Stride))) ||
459-
cast<SCEVAddRecExpr>(StoreEv)->getLoop() != CurLoop)
458+
if (!match(StoreEv,
459+
m_scev_AffineAddRec(m_SCEV(), m_SCEVConstant(Stride), CurLoop)))
460460
return LegalStoreKind::None;
461461

462462
// See if the store can be turned into a memset.
@@ -513,8 +513,7 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
513513

514514
// The store and load must share the same stride.
515515
if (!match(LoadEv,
516-
m_scev_AffineAddRec(m_SCEV(), m_scev_Specific(Stride))) ||
517-
cast<SCEVAddRecExpr>(LoadEv)->getLoop() != CurLoop)
516+
m_scev_AffineAddRec(m_SCEV(), m_scev_Specific(Stride), CurLoop)))
518517
return LegalStoreKind::None;
519518

520519
// Success. This store can be converted into a memcpy.
@@ -787,11 +786,13 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
787786
// See if the load and store pointer expressions are AddRec like {base,+,1} on
788787
// the current loop, which indicates a strided load and store. If we have
789788
// something else, it's a random load or store we can't handle.
790-
const SCEVAddRecExpr *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Dest));
791-
if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
792-
return false;
793-
const SCEVAddRecExpr *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Source));
794-
if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
789+
const SCEV *StoreEv = SE->getSCEV(Dest);
790+
const SCEV *LoadEv = SE->getSCEV(Source);
791+
const APInt *StoreStrideValue, *LoadStrideValue;
792+
if (!match(StoreEv, m_scev_AffineAddRec(
793+
m_SCEV(), m_scev_APInt(StoreStrideValue), CurLoop)) ||
794+
!match(LoadEv, m_scev_AffineAddRec(
795+
m_SCEV(), m_scev_APInt(LoadStrideValue), CurLoop)))
795796
return false;
796797

797798
// Reject memcpys that are so large that they overflow an unsigned.
@@ -801,10 +802,6 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
801802

802803
// Check if the stride matches the size of the memcpy. If so, then we know
803804
// that every byte is touched in the loop.
804-
const APInt *StoreStrideValue, *LoadStrideValue;
805-
if (!match(StoreEv->getOperand(1), m_scev_APInt(StoreStrideValue)) ||
806-
!match(LoadEv->getOperand(1), m_scev_APInt(LoadStrideValue)))
807-
return false;
808805

809806
// Huge stride value - give up
810807
if (StoreStrideValue->getBitWidth() > 64 ||
@@ -830,8 +827,8 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
830827

831828
return processLoopStoreOfLoopLoad(
832829
Dest, Source, SE->getConstant(Dest->getType(), SizeInBytes),
833-
MCI->getDestAlign(), MCI->getSourceAlign(), MCI, MCI, StoreEv, LoadEv,
834-
BECount);
830+
MCI->getDestAlign(), MCI->getSourceAlign(), MCI, MCI,
831+
cast<SCEVAddRecExpr>(StoreEv), cast<SCEVAddRecExpr>(LoadEv), BECount);
835832
}
836833

837834
/// processLoopMemSet - See if this memset can be promoted to a large memset.
@@ -852,12 +849,11 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
852849
// random store we can't handle.
853850
const SCEV *Ev = SE->getSCEV(Pointer);
854851
const SCEV *PointerStrideSCEV;
855-
if (!match(Ev, m_scev_AffineAddRec(m_SCEV(), m_SCEV(PointerStrideSCEV)))) {
852+
if (!match(Ev, m_scev_AffineAddRec(m_SCEV(), m_SCEV(PointerStrideSCEV),
853+
CurLoop))) {
856854
LLVM_DEBUG(dbgs() << " Pointer is not affine, abort\n");
857855
return false;
858856
}
859-
if (cast<SCEVAddRecExpr>(Ev)->getLoop() != CurLoop)
860-
return false;
861857

862858
const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
863859
if (!PointerStrideSCEV || !MemsetSizeSCEV)

0 commit comments

Comments
 (0)