Skip to content

Conversation

@weiweichen
Copy link
Contributor

  • Add a simple canonicalization for mlir::index::AddOp.

@llvmbot
Copy link
Member

llvmbot commented Oct 4, 2024

@llvm/pr-subscribers-mlir-index

Author: weiwei chen (weiweichen)

Changes
  • Add a simple canonicalization for mlir::index::AddOp.

Full diff: https://github.com/llvm/llvm-project/pull/111084.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Index/IR/IndexOps.td (+2)
  • (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+39)
  • (modified) mlir/test/Dialect/Index/index-canonicalize.mlir (+15)
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
     %c = index.add %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 42401dae217ce1..ace9b43014a665 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
+/// Canonicalize
+/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+
+  auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
+    v = op.getLhs();
+    if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
+      v = op.getRhs();
+      if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c)))
+        return false;
+    }
+    return true;
+  };
+
+  IntegerAttr c1, c2;
+  Value v1, v2;
+
+  if (!matchConstant(op, v1, c1))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  auto add = v1.getDefiningOp<mlir::index::AddOp>();
+  if (!add)
+    return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
+
+  if (!matchConstant(add, v2, c2))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+                                                    c1.getInt() + c2.getInt());
+  auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
+
+  rewriter.replaceOp(op, newAdd);
+  return success();
+}
 
 //===----------------------------------------------------------------------===//
 // SubOp
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..256e327e83ea9c 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,21 @@ func.func @add_overflow() -> (index, index) {
   return %2, %3 : index, index
 }
 
+// CHECK-LABEL: @add
+func.func @add_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 1
+  %1 = index.constant 2
+  %2 = index.add %arg, %0
+  %3 = index.add %1, %2
+  %4 = index.add %3, %1
+  %5 = index.add %4, %0
+
+  // CHECK-DAG: [[A:%.*]] = index.constant 6
+  // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+  // CHECK: return [[V0]]
+  return %5 : index
+}
+
 // CHECK-LABEL: @sub
 func.func @sub() -> index {
   %0 = index.constant -2000000000

@llvmbot
Copy link
Member

llvmbot commented Oct 4, 2024

@llvm/pr-subscribers-mlir

Author: weiwei chen (weiweichen)

Changes
  • Add a simple canonicalization for mlir::index::AddOp.

Full diff: https://github.com/llvm/llvm-project/pull/111084.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Index/IR/IndexOps.td (+2)
  • (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+39)
  • (modified) mlir/test/Dialect/Index/index-canonicalize.mlir (+15)
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index a30ae9f739cbc6..ce1355316b09b8 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> {
     %c = index.add %a, %b
     ```
   }];
+
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 42401dae217ce1..ace9b43014a665 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -136,6 +136,45 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
 
   return {};
 }
+/// Canonicalize
+/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
+/// ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
+LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
+
+  auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
+    v = op.getLhs();
+    if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
+      v = op.getRhs();
+      if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c)))
+        return false;
+    }
+    return true;
+  };
+
+  IntegerAttr c1, c2;
+  Value v1, v2;
+
+  if (!matchConstant(op, v1, c1))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  auto add = v1.getDefiningOp<mlir::index::AddOp>();
+  if (!add)
+    return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
+
+  if (!matchConstant(add, v2, c2))
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "neither LHS nor RHS is constant");
+
+  auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
+                                                    c1.getInt() + c2.getInt());
+  auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
+
+  rewriter.replaceOp(op, newAdd);
+  return success();
+}
 
 //===----------------------------------------------------------------------===//
 // SubOp
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 37aa33bfde952e..256e327e83ea9c 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -32,6 +32,21 @@ func.func @add_overflow() -> (index, index) {
   return %2, %3 : index, index
 }
 
+// CHECK-LABEL: @add
+func.func @add_fold_constants(%arg: index) -> (index) {
+  %0 = index.constant 1
+  %1 = index.constant 2
+  %2 = index.add %arg, %0
+  %3 = index.add %1, %2
+  %4 = index.add %3, %1
+  %5 = index.add %4, %0
+
+  // CHECK-DAG: [[A:%.*]] = index.constant 6
+  // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
+  // CHECK: return [[V0]]
+  return %5 : index
+}
+
 // CHECK-LABEL: @sub
 func.func @sub() -> index {
   %0 = index.constant -2000000000

return {};
}
/// Canonicalize
/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only permutation of the pattern that you actually need to implement. The canonicalizer will make sure that constants for commutative ops are always on the RHS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, good to know, PR updated!

@weiweichen
Copy link
Contributor Author

@Mogball another 👀 please 🙏 ?

@weiweichen weiweichen merged commit 7191ced into llvm:main Oct 22, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants