diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index cfb1b4c6ea6b4..5fc3840bd1d4c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -172,6 +172,21 @@ struct SCEVBinaryExpr_match {
   }
 };
 
+template <typename Op0_t, typename Op1_t> struct SCEVAffineAddRec_match {
+  Op0_t Op0;
+  Op1_t Op1;
+  const Loop *L;
+
+  SCEVAffineAddRec_match(Op0_t Op0, Op1_t Op1, const Loop *L)
+      : Op0(Op0), Op1(Op1), L(L) {}
+
+  bool match(const SCEV *S) const {
+    auto *E = dyn_cast<SCEVAddRecExpr>(S);
+    return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
+           Op1.match(E->getOperand(1)) && (!L || E->getLoop() == L);
+  }
+};
+
 template <typename SCEVTy, typename Op0_t, typename Op1_t>
 inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>
 m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
@@ -197,9 +212,10 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
 }
 
 template <typename Op0_t, typename Op1_t>
-inline SCEVBinaryExpr_match<SCEVAddRecExpr, Op0_t, Op1_t>
-m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1) {
-  return m_scev_Binary<SCEVAddRecExpr>(Op0, Op1);
+inline SCEVAffineAddRec_match<Op0_t, Op1_t>
+m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1,
+                    const Loop *L = nullptr) {
+  return SCEVAffineAddRec_match<Op0_t, Op1_t>(Op0, Op1, L);
 }
 } // namespace SCEVPatternMatch
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index e774e5fd99cbb..68db70d7abf9b 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -808,8 +808,7 @@ static bool isLoopCounter(PHINode* Phi, Loop *L,
     return false;
 
   const SCEV *S = SE->getSCEV(Phi);
-  if (!match(S, m_scev_AffineAddRec(m_SCEV(), m_scev_One())) ||
-      cast<SCEVAddRecExpr>(S)->getLoop() != L)
+  if (!match(S, m_scev_AffineAddRec(m_SCEV(), m_scev_One(), L)))
     return false;
 
   int LatchIdx = Phi->getBasicBlockIndex(L->getLoopLatch());
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 0d5e0156b22be..2b2f50c7047c9 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -455,8 +455,8 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
   // random store we can't handle.
   const SCEV *StoreEv = SE->getSCEV(StorePtr);
   const SCEVConstant *Stride;
-  if (!match(StoreEv, m_scev_AffineAddRec(m_SCEV(), m_SCEVConstant(Stride))) ||
-      cast<SCEVAddRecExpr>(StoreEv)->getLoop() != CurLoop)
+  if (!match(StoreEv,
+             m_scev_AffineAddRec(m_SCEV(), m_SCEVConstant(Stride), CurLoop)))
     return LegalStoreKind::None;
 
   // See if the store can be turned into a memset.
@@ -513,8 +513,7 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
 
     // The store and load must share the same stride.
     if (!match(LoadEv,
-               m_scev_AffineAddRec(m_SCEV(), m_scev_Specific(Stride))) ||
-        cast<SCEVAddRecExpr>(LoadEv)->getLoop() != CurLoop)
+               m_scev_AffineAddRec(m_SCEV(), m_scev_Specific(Stride), CurLoop)))
       return LegalStoreKind::None;
 
     // Success.  This store can be converted into a memcpy.
@@ -787,11 +786,13 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
   // See if the load and store pointer expressions are AddRec like {base,+,1} on
   // the current loop, which indicates a strided load and store.  If we have
   // something else, it's a random load or store we can't handle.
-  const SCEVAddRecExpr *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Dest));
-  if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
-    return false;
-  const SCEVAddRecExpr *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Source));
-  if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
+  const SCEV *StoreEv = SE->getSCEV(Dest);
+  const SCEV *LoadEv = SE->getSCEV(Source);
+  const APInt *StoreStrideValue, *LoadStrideValue;
+  if (!match(StoreEv, m_scev_AffineAddRec(
+                          m_SCEV(), m_scev_APInt(StoreStrideValue), CurLoop)) ||
+      !match(LoadEv, m_scev_AffineAddRec(
+                         m_SCEV(), m_scev_APInt(LoadStrideValue), CurLoop)))
     return false;
 
   // Reject memcpys that are so large that they overflow an unsigned.
@@ -801,10 +802,6 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
 
   // Check if the stride matches the size of the memcpy. If so, then we know
   // that every byte is touched in the loop.
-  const APInt *StoreStrideValue, *LoadStrideValue;
-  if (!match(StoreEv->getOperand(1), m_scev_APInt(StoreStrideValue)) ||
-      !match(LoadEv->getOperand(1), m_scev_APInt(LoadStrideValue)))
-    return false;
 
   // Huge stride value - give up
   if (StoreStrideValue->getBitWidth() > 64 ||
@@ -830,8 +827,8 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
 
   return processLoopStoreOfLoopLoad(
       Dest, Source, SE->getConstant(Dest->getType(), SizeInBytes),
-      MCI->getDestAlign(), MCI->getSourceAlign(), MCI, MCI, StoreEv, LoadEv,
-      BECount);
+      MCI->getDestAlign(), MCI->getSourceAlign(), MCI, MCI,
+      cast<SCEVAddRecExpr>(StoreEv), cast<SCEVAddRecExpr>(LoadEv), BECount);
 }
 
 /// processLoopMemSet - See if this memset can be promoted to a large memset.
@@ -852,12 +849,11 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
   // random store we can't handle.
   const SCEV *Ev = SE->getSCEV(Pointer);
   const SCEV *PointerStrideSCEV;
-  if (!match(Ev, m_scev_AffineAddRec(m_SCEV(), m_SCEV(PointerStrideSCEV)))) {
+  if (!match(Ev, m_scev_AffineAddRec(m_SCEV(), m_SCEV(PointerStrideSCEV),
+                                     CurLoop))) {
     LLVM_DEBUG(dbgs() << "  Pointer is not affine, abort\n");
     return false;
   }
-  if (cast<SCEVAddRecExpr>(Ev)->getLoop() != CurLoop)
-    return false;
 
   const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
   if (!PointerStrideSCEV || !MemsetSizeSCEV)
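
Usage sketch (illustrative, not part of the patch): the new optional Loop argument folds the "is this AddRec on the loop I care about?" test into the matcher itself, replacing the match-then-cast-and-compare pattern removed at each call site above. A minimal sketch of the before/after shape, assuming a ScalarEvolution &SE, a Loop *L, and a Value *Ptr are in scope:

    using namespace llvm;
    using namespace llvm::SCEVPatternMatch;

    const SCEV *Ev = SE.getSCEV(Ptr);
    const SCEVConstant *Stride;

    // Before: match any affine AddRec, then check its loop by hand.
    if (match(Ev, m_scev_AffineAddRec(m_SCEV(), m_SCEVConstant(Stride))) &&
        cast<SCEVAddRecExpr>(Ev)->getLoop() == L) {
      // ... strided access over L ...
    }

    // After: the third argument restricts the match to L; omitting it
    // (or passing nullptr, the default) matches an affine AddRec on any loop.
    if (match(Ev, m_scev_AffineAddRec(m_SCEV(), m_SCEVConstant(Stride), L))) {
      // ... strided access over L ...
    }

Note that where a caller still needs the SCEVAddRecExpr itself afterwards (as in processLoopMemCpy's call to processLoopStoreOfLoopLoad), a successful match guarantees the cast<SCEVAddRecExpr> is safe.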