From f267ffc3ad9ab22145f13324eb0ee84579d8c367 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Tue, 10 Dec 2024 14:28:31 +0000 Subject: [PATCH 1/4] [SCEV] Add initial pattern matching for SCEV constants. (NFC) Add initial pattern matching for SCEV constants. Follow-up patches will add additional matchers for various SCEV expressions. --- .../Analysis/ScalarEvolutionPatternMatch.h | 59 +++++++++++++++++++ llvm/lib/Analysis/ScalarEvolution.cpp | 13 ++-- 2 files changed, 66 insertions(+), 6 deletions(-) create mode 100644 llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h new file mode 100644 index 0000000000000..636b9f8e1544f --- /dev/null +++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h @@ -0,0 +1,59 @@ +//===- ScalarEvolutionPatternMatch.h - Match on SCEVs -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides a simple and efficient mechanism for performing general +// tree-based pattern matches on SCEVs, based on LLVM's IR pattern matchers. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H +#define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H + +#include "llvm/Analysis/ScalarEvolutionExpressions.h" + +namespace llvm { +namespace SCEVPatternMatch { + +template +bool match(const SCEV *S, const Pattern &P) { + return P.match(S); +} + +/// Match a specified integer value. \p BitWidth optionally specifies the +/// bitwidth the matched constant must have. If it is 0, the matched constant +/// can have any bitwidth. +template struct specific_intval { + APInt Val; + + specific_intval(APInt V) : Val(std::move(V)) {} + + bool match(const SCEV *S) const { + const auto *C = dyn_cast(S); + if (!C) + return false; + + if (BitWidth != 0 && C->getAPInt().getBitWidth() != BitWidth) + return false; + return APInt::isSameValue(C->getAPInt(), Val); + } +}; + +inline specific_intval<0> m_scev_Zero() { + return specific_intval<0>(APInt(64, 0)); +} +inline specific_intval<0> m_scev_One() { + return specific_intval<0>(APInt(64, 1)); +} +inline specific_intval<0> m_scev_MinusOne() { + return specific_intval<0>(APInt(64, -1)); +} + +} // namespace SCEVPatternMatch +} // namespace llvm + +#endif diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index cad10486cbf3f..741431ac8aa15 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -79,6 +79,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/ScalarEvolutionPatternMatch.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Config/llvm-config.h" @@ -133,6 +134,7 @@ using namespace llvm; using namespace PatternMatch; +using namespace SCEVPatternMatch; #define DEBUG_TYPE "scalar-evolution" @@ -3423,9 +3425,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, return S; // 0 udiv Y == 0 - if (const SCEVConstant *LHSC = dyn_cast(LHS)) - if (LHSC->getValue()->isZero()) - return LHS; + if (match(LHS, m_scev_Zero())) + return LHS; if (const SCEVConstant *RHSC = dyn_cast(RHS)) { if (RHSC->getValue()->isOne()) @@ -10593,7 +10594,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, // Get the initial value for the loop. const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); - const SCEVConstant *StepC = dyn_cast(Step); if (!isLoopInvariant(Step, L)) return getCouldNotCompute(); @@ -10615,8 +10615,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, // Handle unitary steps, which cannot wraparound. // 1*N = -Start; -1*N = Start (mod 2^BW), so: // N = Distance (as unsigned) - if (StepC && - (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) { + + if (match(Step, m_CombineOr(m_scev_One(), m_scev_MinusOne()))) { APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards)); MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance)); @@ -10668,6 +10668,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, } // Solve the general equation. + const SCEVConstant *StepC = dyn_cast(Step); if (!StepC || StepC->getValue()->isZero()) return getCouldNotCompute(); const SCEV *E = SolveLinEquationWithOverflow( From 2b69817daac7f2caa84a9dc129f270a10467224c Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Tue, 10 Dec 2024 20:46:33 +0000 Subject: [PATCH 2/4] !fixup build int matcher on top of PatternMatch::specific_int64 --- .../Analysis/ScalarEvolutionPatternMatch.h | 34 ++++++------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h index 636b9f8e1544f..a6658df32db31 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h @@ -15,6 +15,7 @@ #define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/PatternMatch.h" namespace llvm { namespace SCEVPatternMatch { @@ -24,34 +25,19 @@ bool match(const SCEV *S, const Pattern &P) { return P.match(S); } -/// Match a specified integer value. \p BitWidth optionally specifies the -/// bitwidth the matched constant must have. If it is 0, the matched constant -/// can have any bitwidth. -template struct specific_intval { - APInt Val; +struct specific_intval64 : public PatternMatch::specific_intval64 { + specific_intval64(uint64_t V) : PatternMatch::specific_intval64(V) {} - specific_intval(APInt V) : Val(std::move(V)) {} - - bool match(const SCEV *S) const { - const auto *C = dyn_cast(S); - if (!C) - return false; - - if (BitWidth != 0 && C->getAPInt().getBitWidth() != BitWidth) - return false; - return APInt::isSameValue(C->getAPInt(), Val); + bool match(const SCEV *S) { + auto *Cast = dyn_cast(S); + return Cast && + PatternMatch::specific_intval64::match(Cast->getValue()); } }; -inline specific_intval<0> m_scev_Zero() { - return specific_intval<0>(APInt(64, 0)); -} -inline specific_intval<0> m_scev_One() { - return specific_intval<0>(APInt(64, 1)); -} -inline specific_intval<0> m_scev_MinusOne() { - return specific_intval<0>(APInt(64, -1)); -} +inline specific_intval64 m_scev_Zero() { return specific_intval64(0); } +inline specific_intval64 m_scev_One() { return specific_intval64(1); } +inline specific_intval64 m_scev_MinusOne() { return specific_intval64(-1); } } // namespace SCEVPatternMatch } // namespace llvm From 8758a1d15181d742cc4b8c44e6f597994726a637 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Wed, 11 Dec 2024 11:59:25 +0000 Subject: [PATCH 3/4] !fixup use predicates --- .../Analysis/ScalarEvolutionPatternMatch.h | 42 +++++++++++++------ llvm/lib/Analysis/ScalarEvolution.cpp | 27 +++--------- 2 files changed, 36 insertions(+), 33 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h index a6658df32db31..b4adfe2c2d8a4 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h @@ -1,5 +1,4 @@ -//===- ScalarEvolutionPatternMatch.h - Match on SCEVs -----------*- C++ -*-===// -// +//===----------------------------------------------------------------------===// // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -15,7 +14,6 @@ #define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H #include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/IR/PatternMatch.h" namespace llvm { namespace SCEVPatternMatch { @@ -25,19 +23,39 @@ bool match(const SCEV *S, const Pattern &P) { return P.match(S); } -struct specific_intval64 : public PatternMatch::specific_intval64 { - specific_intval64(uint64_t V) : PatternMatch::specific_intval64(V) {} - +template struct cst_pred_ty : public Predicate { bool match(const SCEV *S) { - auto *Cast = dyn_cast(S); - return Cast && - PatternMatch::specific_intval64::match(Cast->getValue()); + assert((isa(S) || !S->getType()->isVectorTy()) && + "no vector types expected from SCEVs"); + auto *C = dyn_cast(S); + return C && this->isValue(C->getAPInt()); + } +}; + +struct is_zero { + template bool match(ITy *S) { + assert((isa(S) || !S->getType()->isVectorTy()) && + "no vector types expected from SCEVs"); + auto *C = dyn_cast(S); + return C && C->getValue()->isNullValue(); } }; +/// Match any null constant. +inline is_zero m_scev_Zero() { return is_zero(); } + +struct is_one { + bool isValue(const APInt &C) { return C.isOne(); } +}; +/// Match an integer 1. +inline cst_pred_ty m_scev_One() { return cst_pred_ty(); } -inline specific_intval64 m_scev_Zero() { return specific_intval64(0); } -inline specific_intval64 m_scev_One() { return specific_intval64(1); } -inline specific_intval64 m_scev_MinusOne() { return specific_intval64(-1); } +struct is_all_ones { + bool isValue(const APInt &C) { return C.isAllOnes(); } +}; +/// Match an integer with all bits set. +inline cst_pred_ty m_scev_AllOnes() { + return cst_pred_ty(); +} } // namespace SCEVPatternMatch } // namespace llvm diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 741431ac8aa15..e18133971f5bf 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -445,23 +445,11 @@ ArrayRef SCEV::operands() const { llvm_unreachable("Unknown SCEV kind!"); } -bool SCEV::isZero() const { - if (const SCEVConstant *SC = dyn_cast(this)) - return SC->getValue()->isZero(); - return false; -} +bool SCEV::isZero() const { return match(this, m_scev_Zero()); } -bool SCEV::isOne() const { - if (const SCEVConstant *SC = dyn_cast(this)) - return SC->getValue()->isOne(); - return false; -} +bool SCEV::isOne() const { return match(this, m_scev_One()); } -bool SCEV::isAllOnesValue() const { - if (const SCEVConstant *SC = dyn_cast(this)) - return SC->getValue()->isMinusOne(); - return false; -} +bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); } bool SCEV::isNonConstantNegative() const { const SCEVMulExpr *Mul = dyn_cast(this); @@ -10616,7 +10604,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, // 1*N = -Start; -1*N = Start (mod 2^BW), so: // N = Distance (as unsigned) - if (match(Step, m_CombineOr(m_scev_One(), m_scev_MinusOne()))) { + if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) { APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards)); MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance)); @@ -15511,9 +15499,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock( // If we have LHS == 0, check if LHS is computing a property of some unknown // SCEV %v which we can rewrite %v to express explicitly. - const SCEVConstant *RHSC = dyn_cast(RHS); - if (Predicate == CmpInst::ICMP_EQ && RHSC && - RHSC->getValue()->isNullValue()) { + if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) { // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to // explicitly express that. const SCEV *URemLHS = nullptr; @@ -15694,8 +15680,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock( To = RHS; break; case CmpInst::ICMP_NE: - if (isa(RHS) && - cast(RHS)->getValue()->isNullValue()) { + if (match(RHS, m_scev_Zero())) { const SCEV *OneAlignedUp = DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One; To = SE.getUMaxExpr(FromRewritten, OneAlignedUp); From af03d35afec4493d8ba0589ca942256b663e8e4b Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Wed, 11 Dec 2024 20:44:46 +0000 Subject: [PATCH 4/4] !fixup use cst_pred_ty for zero matcher --- .../llvm/Analysis/ScalarEvolutionPatternMatch.h | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h index b4adfe2c2d8a4..21d2ef3c867d7 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h @@ -33,15 +33,10 @@ template struct cst_pred_ty : public Predicate { }; struct is_zero { - template bool match(ITy *S) { - assert((isa(S) || !S->getType()->isVectorTy()) && - "no vector types expected from SCEVs"); - auto *C = dyn_cast(S); - return C && C->getValue()->isNullValue(); - } + bool isValue(const APInt &C) { return C.isZero(); } }; -/// Match any null constant. -inline is_zero m_scev_Zero() { return is_zero(); } +/// Match an integer 0. +inline cst_pred_ty m_scev_Zero() { return cst_pred_ty(); } struct is_one { bool isValue(const APInt &C) { return C.isOne(); }