diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td index a30ae9f739cbc..ce1355316b09b 100644 --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -56,6 +56,8 @@ def Index_AddOp : IndexBinaryOp<"add", [Commutative, Pure]> { %c = index.add %a, %b ``` }]; + + let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp index 42401dae217ce..dbc63d9d10758 100644 --- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp @@ -136,6 +136,28 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) { return {}; } +/// Canonicalize +/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)` +LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) { + IntegerAttr c1, c2; + if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1))) + return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant"); + + auto add = op.getLhs().getDefiningOp(); + if (!add) + return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add"); + + if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2))) + return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant"); + + auto c = rewriter.create(op->getLoc(), + c1.getInt() + c2.getInt()); + auto newAdd = + rewriter.create(op->getLoc(), add.getLhs(), c); + + rewriter.replaceOp(op, newAdd); + return success(); +} //===----------------------------------------------------------------------===// // SubOp diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir index 37aa33bfde952..a29b09c11f7f6 100644 --- a/mlir/test/Dialect/Index/index-canonicalize.mlir +++ b/mlir/test/Dialect/Index/index-canonicalize.mlir @@ -32,6 +32,19 @@ func.func @add_overflow() -> (index, index) { return %2, %3 : index, index } +// CHECK-LABEL: @add +func.func @add_fold_constants(%arg: index) -> (index) { + %0 = index.constant 1 + %1 = index.constant 2 + %2 = index.add %arg, %0 + %3 = index.add %2, %1 + + // CHECK-DAG: [[C3:%.*]] = index.constant 3 + // CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[C3]] + // CHECK: return [[V0]] + return %3 : index +} + // CHECK-LABEL: @sub func.func @sub() -> index { %0 = index.constant -2000000000