diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h index 9354eef98fe91..f285eacc4c565 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h @@ -380,6 +380,19 @@ m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1, const Loop_t &L) { return SCEVAffineAddRec_match(Op0, Op1, L); } +struct is_undef_or_poison { + bool match(const SCEV *S) const { + const SCEVUnknown *Unknown; + return SCEVPatternMatch::match(S, m_SCEVUnknown(Unknown)) && + isa(Unknown->getValue()); + } +}; + +/// Match an SCEVUnknown wrapping undef or poison. +inline is_undef_or_poison m_scev_UndefOrPoison() { + return is_undef_or_poison(); +} + } // namespace SCEVPatternMatch } // namespace llvm diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 5d88e5f54e3d6..df793de7c817b 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -2998,9 +2998,8 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { if (!StrideExpr) return; - if (auto *Unknown = dyn_cast(StrideExpr)) - if (isa(Unknown->getValue())) - return; + if (match(StrideExpr, m_scev_UndefOrPoison())) + return; LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for " "versioning:"); diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 1d7a8b981b5ee..e1f90264be7a2 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -13695,11 +13695,8 @@ SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const { // Return true when S contains at least an undef value. bool ScalarEvolution::containsUndefs(const SCEV *S) const { - return SCEVExprContains(S, [](const SCEV *S) { - if (const auto *SU = dyn_cast(S)) - return isa(SU->getValue()); - return false; - }); + return SCEVExprContains( + S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); }); } // Return true when S contains a value that is a nullptr.