Conversation

@lipracer
Member

@lipracer lipracer commented Sep 20, 2024

Fixes #107508

@lipracer lipracer changed the title [mlir][affine] fix the issue of celidiv mul childiv expression not sa… [mlir][affine] fix the issue of celidiv mul childiv expression not satisfying commutative Sep 20, 2024
@llvmbot
Member

llvmbot commented Sep 20, 2024

@llvm/pr-subscribers-llvm-adt
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-affine

Author: long.chen (lipracer)

Changes

Fixes #107508


Full diff: https://github.com/llvm/llvm-project/pull/109382.diff

3 Files Affected:

  • (modified) llvm/include/llvm/ADT/ScopeExit.h (+6-5)
  • (modified) mlir/lib/IR/AffineExpr.cpp (+115-49)
  • (modified) mlir/test/Dialect/Affine/simplify-structures.mlir (+19-3)
diff --git a/llvm/include/llvm/ADT/ScopeExit.h b/llvm/include/llvm/ADT/ScopeExit.h
index 2f13fb65d34d80..7e126479df3a14 100644
--- a/llvm/include/llvm/ADT/ScopeExit.h
+++ b/llvm/include/llvm/ADT/ScopeExit.h
@@ -31,13 +31,14 @@ template <typename Callable> class scope_exit {
   template <typename Fp>
   explicit scope_exit(Fp &&F) : ExitFunction(std::forward<Fp>(F)) {}
 
-  scope_exit(scope_exit &&Rhs)
-      : ExitFunction(std::move(Rhs.ExitFunction)), Engaged(Rhs.Engaged) {
-    Rhs.release();
-  }
+  scope_exit(scope_exit &&Rhs) { *this = std::move(Rhs); }
   scope_exit(const scope_exit &) = delete;
-  scope_exit &operator=(scope_exit &&) = delete;
   scope_exit &operator=(const scope_exit &) = delete;
+  scope_exit &operator=(scope_exit &&Rhs) {
+    Engaged = std::exchange(Rhs.Engaged, false);
+    ExitFunction = std::exchange(Rhs.ExitFunction, {});
+    return *this;
+  }
 
   void release() { Engaged = false; }
 
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index fc7ede279643ed..84af1f11045d6d 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -9,6 +9,8 @@
 #include <cmath>
 #include <cstdint>
 #include <limits>
+#include <numeric>
+#include <optional>
 #include <utility>
 
 #include "AffineExprDetail.h"
@@ -18,9 +20,8 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/MathExtras.h"
-#include <numeric>
-#include <optional>
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -362,54 +363,119 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
           opKind == AffineExprKind::CeilDiv) &&
          "unexpected opKind");
-  switch (expr.getKind()) {
-  case AffineExprKind::Constant:
-    return cast<AffineConstantExpr>(expr).getValue() == 0;
-  case AffineExprKind::DimId:
-    return false;
-  case AffineExprKind::SymbolId:
-    return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
-  // Checks divisibility by the given symbol for both operands.
-  case AffineExprKind::Add: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
-  }
-  // Checks divisibility by the given symbol for both operands. Consider the
-  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
-  // this is a division by s1 and both the operands of modulo are divisible by
-  // s1 but it is not divisible by s1 always. The third argument is
-  // `AffineExprKind::Mod` for this reason.
-  case AffineExprKind::Mod: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
-                               AffineExprKind::Mod) &&
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
-                               AffineExprKind::Mod);
-  }
-  // Checks if any of the operand divisible by the given symbol.
-  case AffineExprKind::Mul: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
-           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
-  }
-  // Floordiv and ceildiv are divisible by the given symbol when the first
-  // operand is divisible, and the affine expression kind of the argument expr
-  // is same as the argument `opKind`. This can be inferred from commutative
-  // property of floordiv and ceildiv operations and are as follow:
-  // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
-  // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
-  // It will fail if operations are not same. For example:
-  // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
-  case AffineExprKind::FloorDiv:
-  case AffineExprKind::CeilDiv: {
-    AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
-    if (opKind != expr.getKind())
-      return false;
-    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
-  }
+  std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind,
+                         llvm::detail::scope_exit<std::function<void(void)>>>>
+      stack;
+  stack.emplace_back(expr, symbolPos, opKind, []() {});
+  bool result = false;
+
+  while (!stack.empty()) {
+    AffineExpr expr = std::get<0>(stack.back());
+    unsigned symbolPos = std::get<1>(stack.back());
+    AffineExprKind opKind = std::get<2>(stack.back());
+
+    switch (expr.getKind()) {
+    case AffineExprKind::Constant: {
+      // Note: Assignment must occur before pop, which will affect whether it
+      // enters other execution branches.
+      result = cast<AffineConstantExpr>(expr).getValue() == 0;
+      stack.pop_back();
+      break;
+    }
+    case AffineExprKind::DimId: {
+      result = false;
+      stack.pop_back();
+      break;
+    }
+    case AffineExprKind::SymbolId: {
+      result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
+      stack.pop_back();
+      break;
+    }
+    // Checks divisibility by the given symbol for both operands.
+    case AffineExprKind::Add: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos, opKind,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Checks divisibility by the given symbol for both operands. Consider the
+    // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv
+    // s1`, this is a division by s1 and both the operands of modulo are
+    // divisible by s1 but it is not divisible by s1 always. The third argument
+    // is `AffineExprKind::Mod` for this reason.
+    case AffineExprKind::Mod: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos,
+                                 AffineExprKind::Mod,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Checks if either operand is divisible by the given symbol.
+    case AffineExprKind::Mul: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+                         [&stack, &result, binaryExpr, symbolPos, opKind]() {
+                           if (!result) {
+                             stack.emplace_back(
+                                 binaryExpr.getRHS(), symbolPos, opKind,
+                                 [&stack]() { stack.pop_back(); });
+                           } else {
+                             stack.pop_back();
+                           }
+                         });
+      break;
+    }
+    // Floordiv and ceildiv are divisible by the given symbol when the first
+    // operand is divisible and the kind of `expr` is the same as `opKind`.
+    // This can be inferred from the commutative property of floordiv and
+    // ceildiv:
+    // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
+    // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv exp2
+    // The check fails in two cases. 1. The operations differ, e.g.
+    // (exp1 ceildiv exp2) floordiv exp3 cannot be simplified. 2. A
+    // multiplication encloses the division, e.g.
+    // ((exp1 ceildiv exp2) * exp3) ceildiv exp4 cannot be simplified.
+    case AffineExprKind::FloorDiv:
+    case AffineExprKind::CeilDiv: {
+      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+      if (opKind != expr.getKind()) {
+        result = false;
+        stack.pop_back();
+        break;
+      }
+      if (llvm::any_of(stack, [](auto &it) {
+            return std::get<0>(it).getKind() == AffineExprKind::Mul;
+          })) {
+        result = false;
+        stack.pop_back();
+        break;
+      }
+
+      stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(),
+                         [&stack]() { stack.pop_back(); });
+      break;
+    }
+      llvm_unreachable("Unknown AffineExpr");
+    }
   }
-  llvm_unreachable("Unknown AffineExpr");
+  return result;
 }
 
 /// Divides the given expression by the given symbol at position `symbolPos`. It
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
 }
 
 // Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+  // CHECK:       %[[CST:.*]] = arith.constant 43
+  return %a : index
+}
+
+// Tests that a semi-affine expression with a nested ceildiv-mul-ceildiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv_mul_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
   %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
-  // CHECK:       %[[CST:.*]] = arith.constant 47
+  // CHECK-NOT:       arith.constant
+  return %a : index
+}
+
+// Tests that a semi-affine expression with a nested floordiv-mul-floordiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_floordiv_mul_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+  // CHECK-NOT:       arith.constant
   return %a : index
 }
 

@lipracer lipracer force-pushed the fix-107508 branch 4 times, most recently from af8be8c to 21587f8 Compare September 20, 2024 10:42
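
To see why the old `CHECK` for the constant 47 could not hold, the two `ceildiv` test expressions from the diff can be evaluated by hand for a few positive values of `s0`. The following is only a sanity-check sketch: `ceilDiv`, `withoutMul`, and `withMul` are hypothetical helper names and are not part of MLIR, and only positive divisors are considered.

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division for a positive divisor (the only case exercised here).
int64_t ceilDiv(int64_t lhs, int64_t rhs) {
  assert(rhs > 0 && "divisor must be positive");
  return lhs >= 0 ? (lhs + rhs - 1) / rhs : -(-lhs / rhs);
}

// (((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0
// This is the expression the new test expects to fold to a constant.
int64_t withoutMul(int64_t s0) {
  return ceilDiv(ceilDiv(s0 * 2, 4) + s0 * 42, s0);
}

// ((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0
// This is the expression the new test expects NOT to fold to a constant.
int64_t withMul(int64_t s0) {
  return ceilDiv(ceilDiv(s0 * 2, 4) * 5 + s0 * 42, s0);
}
```

With these helpers, `withoutMul` gives the same value for `s0` = 1, 2, 3 and 100, while `withMul(1)` and `withMul(2)` differ, so the second expression cannot be replaced by a single constant.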
@jreiffers
Member

Hey, I see you're still actively changing this. Let me know when it's ready for review.

@lipracer
Member Author

lipracer commented Sep 20, 2024

> Hey, I see you're still actively changing this. Let me know when it's ready for review.

Ready for review now.

@lipracer lipracer changed the title [mlir][affine] fix the issue of celidiv mul childiv expression not satisfying commutative [mlir][affine] fix the issue of celidiv mul ceildiv expression not satisfying commutative Sep 22, 2024
@lipracer lipracer added the awaiting-review Has pending Phabricator review label Sep 23, 2024
@github-actions

github-actions bot commented Sep 28, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@lipracer lipracer requested review from jreiffers October 4, 2024 04:18
return static_cast<ImplType *>(expr)->position;
}

/// A manually managed stack used to convert recursive function calls into

Member

Why not just use AffineExpr::walk? It'll do a post-order traversal, so you could just have a stack of partial results. Unless I'm still missing something :-).

Member Author

This requires a pre-order traversal. Perhaps we can refactor AffineExprVisitor to support pre-order traversal.

Member

Can you give me a small example/explanation where pre-order traversal is needed?

Member Author

Sorry, I didn't express that clearly. The implementation of isDivisibleBySymbol requires a pre-order traversal, e.g.:

case add: visit(lhs) and visit(rhs)
case mul: visit(lhs) or visit(rhs)

The traversal has to be controlled based on the current expression kind. Perhaps having the visitor return both an interrupt signal and a boolean could meet this requirement, but the visitor would also need to be reimplemented.
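
The `add`/`mul` cases above can be sketched as a top-down recursion with short-circuiting. This is a minimal illustration on a hypothetical toy expression tree (`Kind`, `Expr`, `leaf`, `node` are invented for this example and are not MLIR's `AffineExpr` classes), and it leaves out the `Mod`, `FloorDiv` and `CeilDiv` cases and the `opKind` argument of the real function.

```cpp
#include <memory>
#include <utility>

// Hypothetical toy expression tree; NOT MLIR's AffineExpr.
enum class Kind { Constant, Dim, Symbol, Add, Mul };

struct Expr {
  Kind kind;
  long value; // Constant value, or Dim/Symbol position.
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr leaf(Kind kind, long value) {
  return std::make_shared<Expr>(Expr{kind, value, nullptr, nullptr});
}
ExprPtr node(Kind kind, ExprPtr lhs, ExprPtr rhs) {
  return std::make_shared<Expr>(Expr{kind, 0, std::move(lhs), std::move(rhs)});
}

// The parent decides, before descending, which children to visit and how to
// combine them. `&&` and `||` skip the right operand when the left one
// already determines the answer.
bool divisibleBySymbol(const ExprPtr &e, long symbolPos) {
  switch (e->kind) {
  case Kind::Constant:
    return e->value == 0;
  case Kind::Dim:
    return false;
  case Kind::Symbol:
    return e->value == symbolPos;
  case Kind::Add: // Both operands must be divisible.
    return divisibleBySymbol(e->lhs, symbolPos) &&
           divisibleBySymbol(e->rhs, symbolPos);
  case Kind::Mul: // One divisible operand is enough.
    return divisibleBySymbol(e->lhs, symbolPos) ||
           divisibleBySymbol(e->rhs, symbolPos);
  }
  return false;
}
```

For example, `s0 * d0 + s0 * 3` is reported as divisible by `s0`, while `s0 + d0` is not.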

Member

In a post-order traversal with a stack for intermediates, you'd have the results for lhs and rhs on the top of the stack when you visit an add, right? Can you give me a more complete example where this doesn't work?
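
The post-order idea can be sketched roughly as follows, again on a hypothetical toy tree rather than MLIR's `AffineExpr` (`Kind`, `Expr`, `walkPostOrder` and the helpers are assumptions made for this example). The sketch covers only `add` and `mul`; it visits every node, so it does not short-circuit, and it does not model the `opKind` value that the real function passes from a parent down to its children.

```cpp
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy expression tree; NOT MLIR's AffineExpr.
enum class Kind { Constant, Dim, Symbol, Add, Mul };

struct Expr {
  Kind kind;
  long value; // Constant value, or Dim/Symbol position.
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr leaf(Kind kind, long value) {
  return std::make_shared<Expr>(Expr{kind, value, nullptr, nullptr});
}
ExprPtr node(Kind kind, ExprPtr lhs, ExprPtr rhs) {
  return std::make_shared<Expr>(Expr{kind, 0, std::move(lhs), std::move(rhs)});
}

// Stand-in for a post-order walk: children first, then the node itself.
void walkPostOrder(const ExprPtr &e,
                   const std::function<void(const Expr &)> &fn) {
  if (e->lhs)
    walkPostOrder(e->lhs, fn);
  if (e->rhs)
    walkPostOrder(e->rhs, fn);
  fn(*e);
}

bool divisibleBySymbol(const ExprPtr &root, long symbolPos) {
  std::vector<bool> results; // Stack of partial results.
  walkPostOrder(root, [&](const Expr &e) {
    switch (e.kind) {
    case Kind::Constant:
      results.push_back(e.value == 0);
      break;
    case Kind::Dim:
      results.push_back(false);
      break;
    case Kind::Symbol:
      results.push_back(e.value == symbolPos);
      break;
    case Kind::Add:
    case Kind::Mul: {
      // Both operand results are already on top of the stack.
      bool rhs = results.back();
      results.pop_back();
      bool lhs = results.back();
      results.pop_back();
      results.push_back(e.kind == Kind::Add ? (lhs && rhs) : (lhs || rhs));
      break;
    }
    }
  });
  return results.back();
}
```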

Member

I see no problem with using an explicit stack in addition to the implicit one in ::walk.

If you're concerned about stack overflows, the correct thing to do is to refactor ::walk to not be recursive, since that function is used in many places, so fixing this particular location isn't really sufficient if stack overflows are an issue with ::walk.
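
As a rough idea of what a non-recursive post-order walk could look like, here is a sketch on a hypothetical toy tree. It is not MLIR's `AffineExpr::walk`, and all names (`Kind`, `Expr`, `walkPostOrderIterative`, `postOrderKinds`) are assumptions made for this example. Each stack entry carries a flag that records whether the node's children have already been scheduled.

```cpp
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy expression tree; NOT MLIR's AffineExpr.
enum class Kind { Constant, Dim, Symbol, Add, Mul };

struct Expr {
  Kind kind;
  long value; // Constant value, or Dim/Symbol position.
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr leaf(Kind kind, long value) {
  return std::make_shared<Expr>(Expr{kind, value, nullptr, nullptr});
}
ExprPtr node(Kind kind, ExprPtr lhs, ExprPtr rhs) {
  return std::make_shared<Expr>(Expr{kind, 0, std::move(lhs), std::move(rhs)});
}

// Post-order walk with an explicit stack instead of recursion.
void walkPostOrderIterative(const ExprPtr &root,
                            const std::function<void(const Expr &)> &fn) {
  // Pair of (node, children already scheduled?).
  std::vector<std::pair<const Expr *, bool>> stack;
  stack.emplace_back(root.get(), false);
  while (!stack.empty()) {
    const Expr *e = stack.back().first;
    bool expanded = stack.back().second;
    stack.pop_back();
    if (expanded) {
      fn(*e); // Children are done; visit the node itself.
      continue;
    }
    // Revisit this node after its children. Push rhs before lhs so that lhs
    // is popped, and therefore visited, first.
    stack.emplace_back(e, true);
    if (e->rhs)
      stack.emplace_back(e->rhs.get(), false);
    if (e->lhs)
      stack.emplace_back(e->lhs.get(), false);
  }
}

// Helper for illustration: records the order in which node kinds are visited.
std::vector<Kind> postOrderKinds(const ExprPtr &root) {
  std::vector<Kind> order;
  walkPostOrderIterative(root,
                         [&order](const Expr &e) { order.push_back(e.kind); });
  return order;
}
```

For `(s0 * d0) + 3`, the visit order is symbol, dim, mul, constant, add, which matches what a recursive post-order walk would produce.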

Collaborator

Refactoring MLIR walk() to not be recursive is a long-standing issue: patches to fix this are welcome!

Member Author

OK. I will switch to a `walk`-based approach, and then refactor AffineExprVisitor in another patch to use a stack and avoid recursive calls.

Member Author

On why I did not use walk:
This is the walk implementation. At this point, lhs and rhs have already been traversed, and I don't know where the stack push would go.

Member Author

@joker-eph The walk here is the traversal method of AffineExpr. When you mention mlir::walk, do you mean the walk over mlir::Operation?

@lipracer lipracer closed this Oct 5, 2024
Labels

awaiting-review Has pending Phabricator review llvm:adt mlir:affine mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mlir] Semantic inconsistency in ceildiv optimization

4 participants