[mlir][arith] Add canonicalize pattern for max/min with constants #161057
base: main
Add canonicalization patterns for nested min/max operations with constants, e.g.:

max(max(x, c0), c1) -> max(x, max(c0, c1))
min(min(x, c0), c1) -> min(x, min(c0, c1))

Patterns are added for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.
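The rewrite is sound because signed and unsigned min/max are associative. As a quick sanity check (a sketch, not part of the patch), the identity can be verified exhaustively on 8-bit operands, where `int8_t` stands in for the signed ops and `uint8_t` for the unsigned ones. The helper and lambda names below are hypothetical.

```cpp
#include <algorithm>
#include <cstdint>

// Exhaustively checks op(op(x, c0), c1) == op(x, op(c0, c1)) over every 8-bit
// value of T. T = int8_t models maxsi/minsi; T = uint8_t models maxui/minui.
template <typename T, typename BinOp>
bool nestedConstantRewriteHolds(BinOp op) {
  for (int x = 0; x < 256; ++x)
    for (int c0 = 0; c0 < 256; ++c0)
      for (int c1 = 0; c1 < 256; ++c1) {
        const T vx = static_cast<T>(x);
        const T v0 = static_cast<T>(c0);
        const T v1 = static_cast<T>(c1);
        if (op(op(vx, v0), v1) != op(vx, op(v0, v1)))
          return false;
      }
  return true;
}

// 8-bit stand-ins for the four ops touched by this PR. Each lambda returns
// its result by value (std::max/std::min return a reference).
const auto signedMax = [](int8_t a, int8_t b) { return std::max(a, b); };
const auto signedMin = [](int8_t a, int8_t b) { return std::min(a, b); };
const auto unsignedMax = [](uint8_t a, uint8_t b) { return std::max(a, b); };
const auto unsignedMin = [](uint8_t a, uint8_t b) { return std::min(a, b); };
```

Checking all 2^24 triples per op is cheap at this width; wider types (i32, index) rely on the same algebraic property rather than on enumeration.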
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir

Author: Ziliang Zhang (ziliangzl)

Changes

Add canonicalization patterns for nested min/max operations with constants, e.g.:

max(max(x, c0), c1) -> max(x, max(c0, c1))

Patterns are added for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.

Full diff: https://github.com/llvm/llvm-project/pull/161057.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 20c9097b51e6d..739d0439c4bba 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1008,6 +1008,7 @@ def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
let summary = "signed integer maximum operation";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -1017,6 +1018,7 @@ def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
def Arith_MaxUIOp : Arith_TotalIntBinaryOp<"maxui", [Commutative]> {
let summary = "unsigned integer maximum operation";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -1067,6 +1069,7 @@ def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
let summary = "signed integer minimum operation";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -1076,6 +1079,7 @@ def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
let summary = "unsigned integer minimum operation";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index de3efc9fe3506..ef57af86f0540 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,18 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
// Multiply two integer attributes and create a new one with the result.
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
+// Select signed min value of two integer attributes and store to the result
+def SMinIntAttrs : NativeCodeCall<"sminIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Select unsigned min value of two integer attributes and store to the result
+def UMinIntAttrs : NativeCodeCall<"uminIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Select signed max value of two integer attributes and store to the result
+def SMaxIntAttrs : NativeCodeCall<"smaxIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Select unsigned max value of two integer attributes and store to the result
+def UMaxIntAttrs : NativeCodeCall<"umaxIntegerAttrs($_builder, $0, $1, $2)">;
+
// Merge overflow flags from 2 ops, selecting the most conservative combination.
def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
@@ -202,6 +214,62 @@ def MulUIExtendedToMulI :
[(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+// maxsi is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// maxsi(maxsi(x, c0), c1) -> maxsi(x, maxsi(c0, c1))
+def MaxSIMaxSIConstant :
+ Pat<(Arith_MaxSIOp:$res
+ (Arith_MaxSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_MaxSIOp $x, (Arith_ConstantOp (SMaxIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+// maxui is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// maxui(maxui(x, c0), c1) -> maxui(x, maxui(c0, c1))
+def MaxUIMaxUIConstant :
+ Pat<(Arith_MaxUIOp:$res
+ (Arith_MaxUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_MaxUIOp $x, (Arith_ConstantOp (UMaxIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+// minsi is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// minsi(minsi(x, c0), c1) -> minsi(x, minsi(c0, c1))
+def MinSIMinSIConstant :
+ Pat<(Arith_MinSIOp:$res
+ (Arith_MinSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_MinSIOp $x, (Arith_ConstantOp (SMinIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+// minui is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// minui(minui(x, c0), c1) -> minui(x, minui(c0, c1))
+def MinUIMinUIConstant :
+ Pat<(Arith_MinUIOp:$res
+ (Arith_MinUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_MinUIOp $x, (Arith_ConstantOp (UMinIntAttrs $res, $c0, $c1)))>;
+
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
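The DRR patterns above only match when both constants sit on the right-hand side, and their comments rely on the commutative canonicalization to put them there. The toy model below (plain C++, not MLIR; `Expr`, `canonicalizeCommutative`, `foldNestedMax`, and `simplify` are hypothetical names, and only `maxsi` on `int32_t` is modeled) sketches that interaction. It assumes a simplified bottom-up driver rather than MLIR's greedy rewriter, and shows why an input shaped like `@maxsiMaxsiConst2` (outer constant on the left) still reduces.

```cpp
#include <algorithm>
#include <cstdint>
#include <memory>
#include <utility>

// Hypothetical toy IR node: a constant, the function argument, or maxsi(lhs, rhs).
struct Expr {
  enum Kind { Const, Arg, Max };
  Kind kind;
  int32_t value;                    // used only when kind == Const
  std::shared_ptr<Expr> lhs, rhs;   // used only when kind == Max
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr constant(int32_t v) { return std::make_shared<Expr>(Expr{Expr::Const, v, nullptr, nullptr}); }
ExprPtr arg() { return std::make_shared<Expr>(Expr{Expr::Arg, 0, nullptr, nullptr}); }
ExprPtr maxsi(ExprPtr a, ExprPtr b) { return std::make_shared<Expr>(Expr{Expr::Max, 0, a, b}); }

bool isConst(const ExprPtr &e) { return e->kind == Expr::Const; }

// Models the commutative canonicalization: move a lone constant operand to the RHS.
void canonicalizeCommutative(const ExprPtr &e) {
  if (e->kind == Expr::Max && isConst(e->lhs) && !isConst(e->rhs))
    std::swap(e->lhs, e->rhs);
}

// Models MaxSIMaxSIConstant: maxsi(maxsi(x, c0), c1) -> maxsi(x, C) where
// C = smax(c0, c1) is computed eagerly. Returns nullptr when the shape does not match.
ExprPtr foldNestedMax(const ExprPtr &e) {
  if (e->kind != Expr::Max || !isConst(e->rhs)) return nullptr;
  const ExprPtr &inner = e->lhs;
  if (inner->kind != Expr::Max || !isConst(inner->rhs)) return nullptr;
  return maxsi(inner->lhs, constant(std::max(inner->rhs->value, e->rhs->value)));
}

// Simplified bottom-up driver: canonicalize operand order, then try the pattern.
ExprPtr simplify(ExprPtr e) {
  if (e->kind != Expr::Max) return e;
  e->lhs = simplify(e->lhs);
  e->rhs = simplify(e->rhs);
  canonicalizeCommutative(e);
  if (ExprPtr folded = foldNestedMax(e)) return simplify(folded);
  return e;
}
```

For example, `maxsi(21, maxsi(x, 7))` first becomes `maxsi(maxsi(x, 7), 21)` and then `maxsi(x, 21)`, matching the CHECK lines of `@maxsiMaxsiConst2`.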
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3a98df8..82270ab64f7ec 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -63,6 +63,26 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}
+static IntegerAttr sminIntegerAttrs(PatternRewriter &builder, Value res,
+ Attribute lhs, Attribute rhs) {
+ return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smin);
+}
+
+static IntegerAttr uminIntegerAttrs(PatternRewriter &builder, Value res,
+ Attribute lhs, Attribute rhs) {
+ return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umin);
+}
+
+static IntegerAttr smaxIntegerAttrs(PatternRewriter &builder, Value res,
+ Attribute lhs, Attribute rhs) {
+ return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smax);
+}
+
+static IntegerAttr umaxIntegerAttrs(PatternRewriter &builder, Value res,
+ Attribute lhs, Attribute rhs) {
+ return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umax);
+}
+
// Merge overflow flags from 2 ops, selecting the most conservative combination.
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
@@ -1162,6 +1182,11 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
});
}
+void arith::MaxSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<MaxSIMaxSIConstant>(context);
+}
+
//===----------------------------------------------------------------------===//
// MaxUIOp
//===----------------------------------------------------------------------===//
@@ -1187,6 +1212,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
});
}
+void arith::MaxUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<MaxUIMaxUIConstant>(context);
+}
+
//===----------------------------------------------------------------------===//
// MinimumFOp
//===----------------------------------------------------------------------===//
@@ -1248,6 +1278,11 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
});
}
+void arith::MinSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<MinSIMinSIConstant>(context);
+}
+
//===----------------------------------------------------------------------===//
// MinUIOp
//===----------------------------------------------------------------------===//
@@ -1273,6 +1308,11 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
});
}
+void arith::MinUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<MinUIMinUIConstant>(context);
+}
+
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
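The four helpers above differ only in how the constant's bit pattern is ordered, which is why signed and unsigned variants need separate folding functions. A minimal sketch with plain 8-bit integers instead of `llvm::APInt` (the `*8` function names are hypothetical); the selection logic is meant to match `llvm::APIntOps::smax`/`umax`/`smin`/`umin` restricted to 8 bits.

```cpp
#include <algorithm>
#include <cstdint>

// Order both 8-bit patterns as signed (two's complement) and return the larger's bits.
uint8_t smax8(uint8_t a, uint8_t b) {
  return static_cast<int8_t>(a) >= static_cast<int8_t>(b) ? a : b;
}
// Order both 8-bit patterns as signed and return the smaller's bits.
uint8_t smin8(uint8_t a, uint8_t b) {
  return static_cast<int8_t>(a) <= static_cast<int8_t>(b) ? a : b;
}
// Unsigned ordering is the natural ordering of uint8_t.
uint8_t umax8(uint8_t a, uint8_t b) { return std::max(a, b); }
uint8_t umin8(uint8_t a, uint8_t b) { return std::min(a, b); }
```

For the i8 bit pattern `0xFF` (`-1` signed, `255` unsigned) paired with `1`, `smax8` keeps `1` while `umax8` keeps `0xFF`, so folding `maxsi` constants with the unsigned helper (or vice versa) would produce a wrong constant.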
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ca3de3a2d7703..1848decc2eb7c 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1952,6 +1952,30 @@ func.func @bitcastChain(%arg: i16) -> f16 {
// -----
+// CHECK-LABEL: @maxsiMaxsiConst1
+// CHECK: %[[C42:.+]] = arith.constant 42 : i32
+// CHECK: %[[RES:.+]] = arith.maxsi %arg0, %[[C42]] : i32
+// CHECK: return %[[RES]]
+func.func @maxsiMaxsiConst1(%arg0: i32) -> i32 {
+ %c17 = arith.constant 17 : i32
+ %c42 = arith.constant 42 : i32
+ %max1 = arith.maxsi %arg0, %c17 : i32
+ %max2 = arith.maxsi %max1, %c42 : i32
+ return %max2 : i32
+}
+
+// CHECK-LABEL: @maxsiMaxsiConst2
+// CHECK: %[[C21:.+]] = arith.constant 21 : i32
+// CHECK: %[[RES:.+]] = arith.maxsi %arg0, %[[C21]] : i32
+// CHECK: return %[[RES]]
+func.func @maxsiMaxsiConst2(%arg0: i32) -> i32 {
+ %c7 = arith.constant 7 : i32
+ %c21 = arith.constant 21 : i32
+ %max1 = arith.maxsi %arg0, %c7 : i32
+ %max2 = arith.maxsi %c21, %max1 : i32
+ return %max2 : i32
+}
+
// CHECK-LABEL: test_maxsi
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127
@@ -1986,6 +2010,30 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
+// CHECK-LABEL: @maxuiMaxuiConst1
+// CHECK: %[[C42:.+]] = arith.constant 42 : index
+// CHECK: %[[RES:.+]] = arith.maxui %arg0, %[[C42]] : index
+// CHECK: return %[[RES]]
+func.func @maxuiMaxuiConst1(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %max1 = arith.maxui %arg0, %c17 : index
+ %max2 = arith.maxui %max1, %c42 : index
+ return %max2 : index
+}
+
+// CHECK-LABEL: @maxuiMaxuiConst2
+// CHECK: %[[C21:.+]] = arith.constant 21 : index
+// CHECK: %[[RES:.+]] = arith.maxui %arg0, %[[C21]] : index
+// CHECK: return %[[RES]]
+func.func @maxuiMaxuiConst2(%arg0: index) -> index {
+ %c7 = arith.constant 7 : index
+ %c21 = arith.constant 21 : index
+ %max1 = arith.maxui %arg0, %c7 : index
+ %max2 = arith.maxui %c21, %max1 : index
+ return %max2 : index
+}
+
// CHECK-LABEL: test_maxui
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1
@@ -2020,6 +2068,30 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
+// CHECK-LABEL: @minsiMinsiConst1
+// CHECK: %[[C17:.+]] = arith.constant 17 : i32
+// CHECK: %[[RES:.+]] = arith.minsi %arg0, %[[C17]] : i32
+// CHECK: return %[[RES]]
+func.func @minsiMinsiConst1(%arg0: i32) -> i32 {
+ %c17 = arith.constant 17 : i32
+ %c42 = arith.constant 42 : i32
+ %min1 = arith.minsi %arg0, %c17 : i32
+ %min2 = arith.minsi %min1, %c42 : i32
+ return %min2 : i32
+}
+
+// CHECK-LABEL: @minsiMinsiConst2
+// CHECK: %[[C7:.+]] = arith.constant 7 : i32
+// CHECK: %[[RES:.+]] = arith.minsi %arg0, %[[C7]] : i32
+// CHECK: return %[[RES]]
+func.func @minsiMinsiConst2(%arg0: i32) -> i32 {
+ %c7 = arith.constant 7 : i32
+ %c21 = arith.constant 21 : i32
+ %min1 = arith.minsi %arg0, %c7 : i32
+ %min2 = arith.minsi %c21, %min1 : i32
+ return %min2 : i32
+}
+
// CHECK-LABEL: test_minsi
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128
@@ -2054,6 +2126,30 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
// -----
+// CHECK-LABEL: @minuiMinuiConst1
+// CHECK: %[[C17:.+]] = arith.constant 17 : index
+// CHECK: %[[RES:.+]] = arith.minui %arg0, %[[C17]] : index
+// CHECK: return %[[RES]]
+func.func @minuiMinuiConst1(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %min1 = arith.minui %arg0, %c17 : index
+ %min2 = arith.minui %min1, %c42 : index
+ return %min2 : index
+}
+
+// CHECK-LABEL: @minuiMinuiConst2
+// CHECK: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[RES:.+]] = arith.minui %arg0, %[[C7]] : index
+// CHECK: return %[[RES]]
+func.func @minuiMinuiConst2(%arg0: index) -> index {
+ %c7 = arith.constant 7 : index
+ %c21 = arith.constant 21 : index
+ %min1 = arith.minui %arg0, %c7 : index
+ %min2 = arith.minui %c21, %min1 : index
+ return %min2 : index
+}
+
// CHECK-LABEL: test_minui
// CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0
@@ -3377,4 +3473,3 @@ func.func @unreachable() {
%add = arith.addi %add, %c1_i64 : i64
cf.br ^unreachable
}
-

Looks to me like these could be implemented as a folder instead?

Also, isn't this something that we should generically do in the folder for the IsCommutative trait?
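For context on this question: the identity used by the patterns, max(max(x, c0), c1) = max(x, max(c0, c1)), is associativity; commutativity is only needed to bring constants to the right-hand side. A hedged sketch in plain C++ (not MLIR) of what a trait-generic version would have to assume; `foldConstantChain` and `evalChain` are hypothetical helpers, not MLIR APIs.

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical generic helper: for a chain op(op(op(x, c0), c1), ..., cn),
// collapse all constants into one value C so the chain can become op(x, C).
// Valid only when `op` is associative. Precondition: `constants` is non-empty.
template <typename T, typename BinOp>
T foldConstantChain(const std::vector<T> &constants, BinOp op) {
  return std::accumulate(constants.begin() + 1, constants.end(),
                         constants.front(), op);
}

// Evaluate the original, un-reassociated chain for a concrete x.
template <typename T, typename BinOp>
T evalChain(T x, const std::vector<T> &constants, BinOp op) {
  for (const T &c : constants)
    x = op(x, c);
  return x;
}
```

With a signed max functor and constants `{17, 42, -5}`, the folded constant is `42` and `evalChain(x, ...)` equals `max(x, 42)` for any `x`; the same holds for addition, but not for subtraction, which is not associative.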

I just followed the same approach as AddIOp. Maybe implementing this as a TableGen pattern could simplify the code?

These should likely be moved to a folder then.

Can you use a DRR pattern to implement folders? More importantly here: