@fhahn fhahn commented Dec 10, 2024

This patch adds initial matchers for SCEV expressions with an arbitrary
number of operands and specializes them for binary add expressions.

Also adds matchers for SCEVConstant and SCEVUnknown.

For now, this patch converts only a few instances to the new matchers, to
make sure everything builds as expected.

The goal of the matchers is to make it easier to write code that matches
SCEV patterns.

Depends on #119389
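
For readers unfamiliar with the matcher style, here is a minimal, self-contained sketch of the idea. It does not use LLVM: `Expr`, `ConstantExpr`, `UnknownExpr`, `AddExpr`, `dynCast`, `m_Constant`, `m_Unknown`, `m_Add` and `isConstantPlusUnknown` are hypothetical stand-ins invented for this illustration, loosely modelled on the `bind_ty` and binary-add matchers in the diff. It only shows how a pattern object with a `match` method composes and binds sub-expressions.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Toy stand-ins for SCEV nodes. These are NOT the LLVM classes.
struct Expr {
  enum Kind { Constant, Unknown, Add } K;
  explicit Expr(Kind K) : K(K) {}
  virtual ~Expr() = default;
};
struct ConstantExpr : Expr {
  long Value;
  explicit ConstantExpr(long V) : Expr(Constant), Value(V) {}
  static bool classof(const Expr *E) { return E->K == Constant; }
};
struct UnknownExpr : Expr {
  UnknownExpr() : Expr(Unknown) {}
  static bool classof(const Expr *E) { return E->K == Unknown; }
};
struct AddExpr : Expr {
  std::vector<const Expr *> Ops;
  explicit AddExpr(std::vector<const Expr *> O)
      : Expr(Add), Ops(std::move(O)) {}
  static bool classof(const Expr *E) { return E->K == Add; }
};

// Minimal dyn_cast replacement built on classof (hypothetical helper).
template <typename To> const To *dynCast(const Expr *E) {
  return To::classof(E) ? static_cast<const To *>(E) : nullptr;
}

// Entry point: a pattern is any object with a match(const Expr *) method.
template <typename Pattern> bool match(const Expr *E, const Pattern &P) {
  assert(E && "expected a non-null expression");
  return P.match(E);
}

// Leaf matcher: succeeds on a node of type Class and binds it to VR.
template <typename Class> struct BindTy {
  const Class *&VR;
  explicit BindTy(const Class *&V) : VR(V) {}
  bool match(const Expr *E) const {
    if (const auto *C = dynCast<Class>(E)) {
      VR = C;
      return true;
    }
    return false;
  }
};
inline BindTy<ConstantExpr> m_Constant(const ConstantExpr *&V) {
  return BindTy<ConstantExpr>(V);
}
inline BindTy<UnknownExpr> m_Unknown(const UnknownExpr *&V) {
  return BindTy<UnknownExpr>(V);
}

// Binary matcher: requires an add with exactly two operands and matches
// each operand against its sub-pattern, in order.
template <typename Op0_t, typename Op1_t> struct AddMatch {
  Op0_t Op0;
  Op1_t Op1;
  AddMatch(Op0_t A, Op1_t B) : Op0(A), Op1(B) {}
  bool match(const Expr *E) const {
    const auto *A = dynCast<AddExpr>(E);
    return A && A->Ops.size() == 2 && Op0.match(A->Ops[0]) &&
           Op1.match(A->Ops[1]);
  }
};
template <typename Op0_t, typename Op1_t>
AddMatch<Op0_t, Op1_t> m_Add(const Op0_t &A, const Op1_t &B) {
  return AddMatch<Op0_t, Op1_t>(A, B);
}

// Usage: recognise `Constant + Unknown` and extract the constant's value.
inline bool isConstantPlusUnknown(const Expr *E, long &Out) {
  const ConstantExpr *C = nullptr;
  const UnknownExpr *U = nullptr;
  if (!match(E, m_Add(m_Constant(C), m_Unknown(U))))
    return false;
  Out = C->Value;
  return true;
}
```

The single `match(...)` call replaces the chain of casts and null checks, which is the same simplification the `MatchRangeCheckIdiom` change in this patch makes.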

@fhahn fhahn requested a review from nikic as a code owner December 10, 2024 15:12
@llvmbot llvmbot added the llvm:analysis Includes value tracking, cost tables and constant folding label Dec 10, 2024
llvmbot commented Dec 10, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Florian Hahn (fhahn)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/119390.diff

2 Files Affected:

  • (added) llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h (+140)
  • (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+12-13)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
new file mode 100644
index 00000000000000..1dff1a672fd225
--- /dev/null
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -0,0 +1,140 @@
+//===- ScalarEvolutionPatternMatch.h - Match on SCEVs -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a simple and efficient mechanism for performing general
+// tree-based pattern matches on SCEVs, based on LLVM's IR pattern matchers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
+#define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
+
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+
+namespace llvm {
+namespace SCEVPatternMatch {
+
+template <typename Pattern>
+bool match(const SCEV *S, const Pattern &P) {
+  return P.match(S);
+}
+
+/// Match a specified integer value. \p BitWidth optionally specifies the
+/// bitwidth the matched constant must have. If it is 0, the matched constant
+/// can have any bitwidth.
+template <unsigned BitWidth = 0> struct specific_intval {
+  APInt Val;
+
+  specific_intval(APInt V) : Val(std::move(V)) {}
+
+  bool match(const SCEV *S) const {
+    const auto *C = dyn_cast<SCEVConstant>(S);
+    if (!C)
+      return false;
+
+    if (BitWidth != 0 && C->getAPInt().getBitWidth() != BitWidth)
+      return false;
+    return APInt::isSameValue(C->getAPInt(), Val);
+  }
+};
+
+inline specific_intval<0> m_scev_Zero() {
+  return specific_intval<0>(APInt(64, 0));
+}
+inline specific_intval<0> m_scev_One() {
+  return specific_intval<0>(APInt(64, 1));
+}
+inline specific_intval<0> m_scev_MinusOne() {
+  return specific_intval<0>(APInt(64, -1));
+}
+
+template <typename Class> struct class_match {
+  template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
+};
+
+template <typename Class> struct bind_ty {
+  Class *&VR;
+
+  bind_ty(Class *&V) : VR(V) {}
+
+  template <typename ITy> bool match(ITy *V) const {
+    if (auto *CV = dyn_cast<Class>(V)) {
+      VR = CV;
+      return true;
+    }
+    return false;
+  }
+};
+
+/// Match a SCEV, capturing it if we match.
+inline bind_ty<const SCEV> m_SCEV(const SCEV *&V) { return V; }
+inline bind_ty<const SCEVConstant> m_SCEVConstant(const SCEVConstant *&V) {
+  return V;
+}
+inline bind_ty<const SCEVUnknown> m_SCEVUnknown(const SCEVUnknown *&V) {
+  return V;
+}
+
+namespace detail {
+
+template <typename TupleTy, typename Fn, std::size_t... Is>
+bool CheckTupleElements(const TupleTy &Ops, Fn P, std::index_sequence<Is...>) {
+  return (P(std::get<Is>(Ops), Is) && ...);
+}
+
+/// Helper to check if predicate \p P holds on all tuple elements in \p Ops
+template <typename TupleTy, typename Fn>
+bool all_of_tuple_elements(const TupleTy &Ops, Fn P) {
+  return CheckTupleElements(
+      Ops, P, std::make_index_sequence<std::tuple_size<TupleTy>::value>{});
+}
+
+} // namespace detail
+
+template <typename Ops_t, typename SCEVTy> struct SCEV_match {
+  Ops_t Ops;
+
+  SCEV_match() : Ops() {
+    static_assert(std::tuple_size<Ops_t>::value == 0 &&
+                  "constructor can only be used with zero operands");
+  }
+  SCEV_match(Ops_t Ops) : Ops(Ops) {}
+  template <typename A_t, typename B_t> SCEV_match(A_t A, B_t B) : Ops({A, B}) {
+    static_assert(std::tuple_size<Ops_t>::value == 2 &&
+                  "constructor can only be used for binary matcher");
+  }
+
+  bool match(const SCEV *S) const {
+    auto *Cast = dyn_cast<SCEVTy>(S);
+    if (!Cast || Cast->getNumOperands() != std::tuple_size<Ops_t>::value)
+      return false;
+    return detail::all_of_tuple_elements(Ops, [Cast](auto Op, unsigned Idx) {
+      return Op.match(Cast->getOperand(Idx));
+    });
+  }
+};
+
+template <typename Op0_t, typename Op1_t, typename SCEVTy>
+using BinarySCEV_match = SCEV_match<std::tuple<Op0_t, Op1_t>, SCEVTy>;
+
+template <typename Op0_t, typename Op1_t, typename SCEVTy>
+inline BinarySCEV_match<Op0_t, Op1_t, SCEVTy> m_scev_Binary(const Op0_t &Op0,
+                                                            const Op1_t &Op1) {
+  return BinarySCEV_match<Op0_t, Op1_t, SCEVTy>(Op0, Op1);
+}
+
+template <typename Op0_t, typename Op1_t>
+inline BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>
+m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
+  return BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>(Op0, Op1);
+}
+
+} // namespace SCEVPatternMatch
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index cad10486cbf3fa..ed92335606ad90 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -79,6 +79,7 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Config/llvm-config.h"
@@ -133,6 +134,7 @@
 
 using namespace llvm;
 using namespace PatternMatch;
+using namespace SCEVPatternMatch;
 
 #define DEBUG_TYPE "scalar-evolution"
 
@@ -3423,9 +3425,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
     return S;
 
   // 0 udiv Y == 0
-  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
-    if (LHSC->getValue()->isZero())
-      return LHS;
+  if (match(LHS, m_scev_Zero()))
+    return LHS;
 
   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
     if (RHSC->getValue()->isOne())
@@ -10593,7 +10594,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   // Get the initial value for the loop.
   const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
   const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
-  const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
 
   if (!isLoopInvariant(Step, L))
     return getCouldNotCompute();
@@ -10615,8 +10615,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   // Handle unitary steps, which cannot wraparound.
   // 1*N = -Start; -1*N = Start (mod 2^BW), so:
   //   N = Distance (as unsigned)
-  if (StepC &&
-      (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
+
+  if (match(Step, m_CombineOr(m_scev_One(), m_scev_MinusOne()))) {
     APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
     MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
 
@@ -10668,6 +10668,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   }
 
   // Solve the general equation.
+  const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
   if (!StepC || StepC->getValue()->isZero())
     return getCouldNotCompute();
   const SCEV *E = SolveLinEquationWithOverflow(
@@ -15392,14 +15393,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     // (X >=u C1).
     auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
                                  &ExprsToRewrite]() {
-      auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
-      if (!AddExpr || AddExpr->getNumOperands() != 2)
-        return false;
-
-      auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
-      auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
+      const SCEVConstant *C1;
+      const SCEVUnknown *LHSUnknown;
       auto *C2 = dyn_cast<SCEVConstant>(RHS);
-      if (!C1 || !C2 || !LHSUnknown)
+      if (!match(LHS,
+                 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
+          !C2)
         return false;
 
       auto ExactRegion =

nikic commented Dec 13, 2024

Needs a rebase.

@fhahn fhahn force-pushed the scev-patternmatch branch from ac5edf5 to a0dc428 Compare December 13, 2024 14:48
fhahn commented Dec 13, 2024

Rebased after #119389

Also added some unary matchers to make use of the flexibility of SCEVExpr_match

@nikic nikic left a comment

This looks reasonable, but I can't claim to understand the template magic going on here. I do wonder if it wouldn't be better to just have separate unary and binary implementations and not try to be overly generic.


template <typename Op0_t, typename Op1_t, typename SCEVTy>
inline BinarySCEVExpr_match<Op0_t, Op1_t, SCEVTy>
m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {

How are you supposed to actually use this one? Don't you have to make SCEVTy the first template parameter to make this ergonomic?
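
One way to read the concern above, as a self-contained sketch: template arguments can only be given explicitly from left to right, so a node type placed last forces the caller to spell out the operand types as well. The names `AddTag`, `MulTag`, `BinaryMatch`, `m_BinaryLast` and `m_BinaryFirst` are hypothetical and exist only for this illustration. They are not part of the patch.

```cpp
#include <cassert>

// Hypothetical tag types standing in for SCEV expression classes.
struct AddTag { static constexpr int Id = 1; };
struct MulTag { static constexpr int Id = 2; };

// A matcher-like aggregate that records its operands and node type.
template <typename Op0_t, typename Op1_t, typename NodeTy> struct BinaryMatch {
  Op0_t Op0;
  Op1_t Op1;
  int nodeId() const { return NodeTy::Id; }
};

// Node type LAST: it cannot be deduced from the arguments, so the caller
// has to write all three template arguments explicitly.
template <typename Op0_t, typename Op1_t, typename NodeTy>
BinaryMatch<Op0_t, Op1_t, NodeTy> m_BinaryLast(const Op0_t &A, const Op1_t &B) {
  return {A, B};
}

// Node type FIRST: the caller names only the node type and the operand
// types are deduced from the arguments.
template <typename NodeTy, typename Op0_t, typename Op1_t>
BinaryMatch<Op0_t, Op1_t, NodeTy> m_BinaryFirst(const Op0_t &A,
                                                const Op1_t &B) {
  return {A, B};
}

inline int demoNodeTypeLast() {
  // Operand types must be repeated by hand.
  return m_BinaryLast<int, double, AddTag>(1, 2.5).nodeId();
}

inline int demoNodeTypeFirst() {
  // Only the node type is spelled out.
  return m_BinaryFirst<MulTag>(1, 2.5).nodeId();
}

inline void checkOrderingDemo() {
  assert(demoNodeTypeLast() == AddTag::Id);
  assert(demoNodeTypeFirst() == MulTag::Id);
}
```

With real matchers the operand types are nested template instantiations, so writing them by hand as in `demoNodeTypeLast` quickly becomes impractical.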

using UnarySCEVExpr_match = SCEVExpr_match<std::tuple<Op0_t>, SCEVTy>;

template <typename Op0_t, typename Op1_t, typename SCEVTy>
inline UnarySCEVExpr_match<Op0_t, SCEVTy> m_scev_Unary(const Op0_t &Op0) {

Same question here.

fhahn commented Dec 15, 2024

> This looks reasonable, but I can't claim to understand the template magic going on here. I do wonder if it wouldn't be better to just have separate unary and binary implementations and not try to be overly generic.

Most of the complexity comes from the lack of a utility for iterating over all members of a tuple. Happy to change it to a non-generic version. My original thought was that it might be convenient to match SCEV expressions with an arbitrary number of operands, but that is probably not very common.
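
For context, the tuple-iteration helper referred to here can be sketched in isolation as follows. This is a standalone C++17 rendering of the `CheckTupleElements` / `all_of_tuple_elements` pair from the diff, using only the standard library. The function `tupleMatchesArray` is a hypothetical usage example added for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <utility>

namespace detail {
// Expands to P(get<0>(Ops), 0) && P(get<1>(Ops), 1) && ... via a fold
// expression. && short-circuits, so evaluation stops at the first failure.
// For an empty tuple the fold over && yields true.
template <typename TupleTy, typename Fn, std::size_t... Is>
bool checkTupleElements(const TupleTy &Ops, Fn P, std::index_sequence<Is...>) {
  return (P(std::get<Is>(Ops), Is) && ...);
}
} // namespace detail

// Returns true if predicate P(element, index) holds for every tuple element.
template <typename TupleTy, typename Fn>
bool allOfTupleElements(const TupleTy &Ops, Fn P) {
  return detail::checkTupleElements(
      Ops, P, std::make_index_sequence<std::tuple_size<TupleTy>::value>{});
}

// Hypothetical usage: check that the I-th tuple element equals Expected[I],
// even though the tuple elements have different types.
inline bool tupleMatchesArray(const std::tuple<int, long, short> &T,
                              const long (&Expected)[3]) {
  return allOfTupleElements(T, [&](auto Elem, std::size_t Idx) {
    assert(Idx < 3 && "index must stay within the expected array");
    return static_cast<long>(Elem) == Expected[Idx];
  });
}
```

The index sequence is what lets a single generic lambda visit elements of different types together with their positions. That is the part a plain loop cannot express, and it is the source of the extra template machinery.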

@fhahn fhahn force-pushed the scev-patternmatch branch from c78ce79 to a71a973 Compare December 17, 2024 10:46
fhahn commented Dec 17, 2024

Updated to have dedicated matchers for unary and binary exprs
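
The updated revision is not shown in this thread, so the following is only a rough sketch of what dedicated unary and binary matchers could look like without the tuple machinery. It is a self-contained toy model: `Node`, `ZExtTag`, `AddTag`, `UnaryMatch`, `BinaryMatch`, `m_Node`, `m_ZExt`, `m_Add` and `isAddOfZExt` are hypothetical names for illustration and are not taken from the merged code.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Toy expression node. This is NOT the LLVM SCEV class.
struct Node {
  enum Kind { Leaf, ZeroExtend, Add } K;
  std::vector<const Node *> Ops;
  Node(Kind K, std::vector<const Node *> Ops = {})
      : K(K), Ops(std::move(Ops)) {}
};

// Hypothetical tags selecting which node kind a matcher accepts.
struct ZExtTag { static constexpr Node::Kind NodeKind = Node::ZeroExtend; };
struct AddTag { static constexpr Node::Kind NodeKind = Node::Add; };

template <typename Pattern> bool match(const Node *N, const Pattern &P) {
  assert(N && "expected a non-null node");
  return P.match(N);
}

// Leaf matcher: matches any node and binds it.
struct BindNode {
  const Node *&VR;
  explicit BindNode(const Node *&V) : VR(V) {}
  bool match(const Node *N) const {
    VR = N;
    return true;
  }
};
inline BindNode m_Node(const Node *&V) { return BindNode(V); }

// Dedicated unary matcher: one operand, no tuple iteration needed.
template <typename Tag, typename Op0_t> struct UnaryMatch {
  Op0_t Op0;
  explicit UnaryMatch(Op0_t Op0) : Op0(Op0) {}
  bool match(const Node *N) const {
    return N->K == Tag::NodeKind && N->Ops.size() == 1 &&
           Op0.match(N->Ops[0]);
  }
};

// Dedicated binary matcher: two operands, matched in order.
template <typename Tag, typename Op0_t, typename Op1_t> struct BinaryMatch {
  Op0_t Op0;
  Op1_t Op1;
  BinaryMatch(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
  bool match(const Node *N) const {
    return N->K == Tag::NodeKind && N->Ops.size() == 2 &&
           Op0.match(N->Ops[0]) && Op1.match(N->Ops[1]);
  }
};

template <typename Op0_t>
UnaryMatch<ZExtTag, Op0_t> m_ZExt(const Op0_t &Op0) {
  return UnaryMatch<ZExtTag, Op0_t>(Op0);
}
template <typename Op0_t, typename Op1_t>
BinaryMatch<AddTag, Op0_t, Op1_t> m_Add(const Op0_t &Op0, const Op1_t &Op1) {
  return BinaryMatch<AddTag, Op0_t, Op1_t>(Op0, Op1);
}

// Usage: recognise `zext(Inner) + Other` and bind both sub-nodes.
inline bool isAddOfZExt(const Node *N, const Node *&Inner,
                        const Node *&Other) {
  return match(N, m_Add(m_ZExt(m_Node(Inner)), m_Node(Other)));
}
```

Compared with the tuple-based version, each matcher names its operands directly, at the cost of one struct per arity.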

@nikic nikic left a comment

LGTM

@fhahn fhahn merged commit 8ea9576 into llvm:main Dec 17, 2024
8 checks passed
@fhahn fhahn deleted the scev-patternmatch branch December 17, 2024 12:13