@ziliangzl

Add canonicalization patterns for nested min/max operations with constants, e.g.:

max(max(x, c0), c1) -> max(x, max(c0, c1))
min(min(x, c0), c1) -> min(x, min(c0, c1))

Patterns are added for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.
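
The rewrite relies on min and max being associative. As a minimal sketch in plain C++ (not the MLIR implementation; `nestedMinMaxRewriteHolds` and `checkNestedMinMaxRewrite` are hypothetical names used only for illustration), the identity can be checked exhaustively for 8-bit signed and unsigned values:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper (not part of the PR): checks that
//   max(max(x, c0), c1) == max(x, max(c0, c1))
//   min(min(x, c0), c1) == min(x, min(c0, c1))
// for every combination of values representable in the 8-bit type T.
template <typename T>
bool nestedMinMaxRewriteHolds() {
  const int lo = std::numeric_limits<T>::min();
  const int hi = std::numeric_limits<T>::max();
  for (int x = lo; x <= hi; ++x)
    for (int c0 = lo; c0 <= hi; ++c0)
      for (int c1 = lo; c1 <= hi; ++c1) {
        // All three values are in range for T, so the casts preserve them.
        const T a = static_cast<T>(x);
        const T b = static_cast<T>(c0);
        const T c = static_cast<T>(c1);
        if (std::max(std::max(a, b), c) != std::max(a, std::max(b, c)))
          return false;
        if (std::min(std::min(a, b), c) != std::min(a, std::min(b, c)))
          return false;
      }
  return true;
}

// Runs the check for the signed and the unsigned interpretation.
void checkNestedMinMaxRewrite() {
  assert(nestedMinMaxRewriteHolds<std::int8_t>());
  assert(nestedMinMaxRewriteHolds<std::uint8_t>());
}
```

The 8-bit width is only an assumption to keep the exhaustive loop small; the patterns in this PR operate on arbitrary-width integer attributes.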

@llvmbot
Member

llvmbot commented Sep 28, 2025

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Ziliang Zhang (ziliangzl)

Changes

Add canonicalization patterns for nested min/max operations with constants, e.g.:

max(max(x, c0), c1) -> max(x, max(c0, c1))
min(min(x, c0), c1) -> min(x, min(c0, c1))

Patterns are added for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.


Full diff: https://github.com/llvm/llvm-project/pull/161057.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+4)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+68)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+40)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+96-1)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 20c9097b51e6d..739d0439c4bba 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1008,6 +1008,7 @@ def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
 def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
   let summary = "signed integer maximum operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1017,6 +1018,7 @@ def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
 def Arith_MaxUIOp : Arith_TotalIntBinaryOp<"maxui", [Commutative]> {
   let summary = "unsigned integer maximum operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1067,6 +1069,7 @@ def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
 def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
   let summary = "signed integer minimum operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1076,6 +1079,7 @@ def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
 def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
   let summary = "unsigned integer minimum operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index de3efc9fe3506..ef57af86f0540 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,18 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
 // Multiply two integer attributes and create a new one with the result.
 def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
 
+// Select signed min value of two integer attributes and store to the result
+def SMinIntAttrs : NativeCodeCall<"sminIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Select unsigned min value of two integer attributes and store to the result
+def UMinIntAttrs : NativeCodeCall<"uminIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Select signed max value of two integer attributes and store to the result
+def SMaxIntAttrs : NativeCodeCall<"smaxIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Select unsigned max value of two integer attributes and store to the result
+def UMaxIntAttrs : NativeCodeCall<"umaxIntegerAttrs($_builder, $0, $1, $2)">;
+
 // Merge overflow flags from 2 ops, selecting the most conservative combination.
 def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
 
@@ -202,6 +214,62 @@ def MulUIExtendedToMulI :
         [(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+// maxsi is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// maxsi(maxsi(x, c0), c1) -> maxsi(x, maxsi(c0, c1))
+def MaxSIMaxSIConstant :
+    Pat<(Arith_MaxSIOp:$res
+          (Arith_MaxSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+          (ConstantLikeMatcher APIntAttr:$c1)),
+        (Arith_MaxSIOp $x, (Arith_ConstantOp (SMaxIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+// maxui is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// maxui(maxui(x, c0), c1) -> maxui(x, maxui(c0, c1))
+def MaxUIMaxUIConstant :
+    Pat<(Arith_MaxUIOp:$res
+          (Arith_MaxUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+          (ConstantLikeMatcher APIntAttr:$c1)),
+        (Arith_MaxUIOp $x, (Arith_ConstantOp (UMaxIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+// minsi is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// minsi(minsi(x, c0), c1) -> minsi(x, minsi(c0, c1))
+def MinSIMinSIConstant :
+    Pat<(Arith_MinSIOp:$res
+          (Arith_MinSIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+          (ConstantLikeMatcher APIntAttr:$c1)),
+        (Arith_MinSIOp $x, (Arith_ConstantOp (SMinIntAttrs $res, $c0, $c1)))>;
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+// minui is commutative and will be canonicalized to have its constants appear
+// as the second operand.
+
+// minui(minui(x, c0), c1) -> minui(x, minui(c0, c1))
+def MinUIMinUIConstant :
+    Pat<(Arith_MinUIOp:$res
+          (Arith_MinUIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+          (ConstantLikeMatcher APIntAttr:$c1)),
+        (Arith_MinUIOp $x, (Arith_ConstantOp (UMinIntAttrs $res, $c0, $c1)))>;
+
 //===----------------------------------------------------------------------===//
 // XOrIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3a98df8..82270ab64f7ec 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -63,6 +63,26 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
   return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
+static IntegerAttr sminIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smin);
+}
+
+static IntegerAttr uminIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umin);
+}
+
+static IntegerAttr smaxIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::smax);
+}
+
+static IntegerAttr umaxIntegerAttrs(PatternRewriter &builder, Value res,
+                                    Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, llvm::APIntOps::umax);
+}
+
 // Merge overflow flags from 2 ops, selecting the most conservative combination.
 static IntegerOverflowFlagsAttr
 mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
@@ -1162,6 +1182,11 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
                                         });
 }
 
+void arith::MaxSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<MaxSIMaxSIConstant>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MaxUIOp
 //===----------------------------------------------------------------------===//
@@ -1187,6 +1212,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
                                         });
 }
 
+void arith::MaxUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<MaxUIMaxUIConstant>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MinimumFOp
 //===----------------------------------------------------------------------===//
@@ -1248,6 +1278,11 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
                                         });
 }
 
+void arith::MinSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<MinSIMinSIConstant>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MinUIOp
 //===----------------------------------------------------------------------===//
@@ -1273,6 +1308,11 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
                                         });
 }
 
+void arith::MinUIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<MinUIMinUIConstant>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ca3de3a2d7703..1848decc2eb7c 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1952,6 +1952,30 @@ func.func @bitcastChain(%arg: i16) -> f16 {
 
 // -----
 
+// CHECK-LABEL: @maxsiMaxsiConst1
+//       CHECK:   %[[C42:.+]] = arith.constant 42 : i32
+//       CHECK:   %[[RES:.+]] = arith.maxsi %arg0, %[[C42]] : i32
+//       CHECK:   return %[[RES]]
+func.func @maxsiMaxsiConst1(%arg0: i32) -> i32 {
+  %c17 = arith.constant 17 : i32
+  %c42 = arith.constant 42 : i32
+  %max1 = arith.maxsi %arg0, %c17 : i32
+  %max2 = arith.maxsi %max1, %c42 : i32
+  return %max2 : i32
+}
+
+// CHECK-LABEL: @maxsiMaxsiConst2
+//       CHECK:   %[[C21:.+]] = arith.constant 21 : i32
+//       CHECK:   %[[RES:.+]] = arith.maxsi %arg0, %[[C21]] : i32
+//       CHECK:   return %[[RES]]
+func.func @maxsiMaxsiConst2(%arg0: i32) -> i32 {
+  %c7 = arith.constant 7 : i32
+  %c21 = arith.constant 21 : i32
+  %max1 = arith.maxsi %arg0, %c7 : i32
+  %max2 = arith.maxsi %c21, %max1 : i32
+  return %max2 : i32
+}
+
 // CHECK-LABEL: test_maxsi
 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
 // CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127
@@ -1986,6 +2010,30 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
 
 // -----
 
+// CHECK-LABEL: @maxuiMaxuiConst1
+//       CHECK:   %[[C42:.+]] = arith.constant 42 : index
+//       CHECK:   %[[RES:.+]] = arith.maxui %arg0, %[[C42]] : index
+//       CHECK:   return %[[RES]]
+func.func @maxuiMaxuiConst1(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %max1 = arith.maxui %arg0, %c17 : index
+  %max2 = arith.maxui %max1, %c42 : index
+  return %max2 : index
+}
+
+// CHECK-LABEL: @maxuiMaxuiConst2
+//       CHECK:   %[[C21:.+]] = arith.constant 21 : index
+//       CHECK:   %[[RES:.+]] = arith.maxui %arg0, %[[C21]] : index
+//       CHECK:   return %[[RES]]
+func.func @maxuiMaxuiConst2(%arg0: index) -> index {
+  %c7 = arith.constant 7 : index
+  %c21 = arith.constant 21 : index
+  %max1 = arith.maxui %arg0, %c7 : index
+  %max2 = arith.maxui %c21, %max1 : index
+  return %max2 : index
+}
+
 // CHECK-LABEL: test_maxui
 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
 // CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1
@@ -2020,6 +2068,30 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
 
 // -----
 
+// CHECK-LABEL: @minsiMinsiConst1
+//       CHECK:   %[[C17:.+]] = arith.constant 17 : i32
+//       CHECK:   %[[RES:.+]] = arith.minsi %arg0, %[[C17]] : i32
+//       CHECK:   return %[[RES]]
+func.func @minsiMinsiConst1(%arg0: i32) -> i32 {
+  %c17 = arith.constant 17 : i32
+  %c42 = arith.constant 42 : i32
+  %min1 = arith.minsi %arg0, %c17 : i32
+  %min2 = arith.minsi %min1, %c42 : i32
+  return %min2 : i32
+}
+
+// CHECK-LABEL: @minsiMinsiConst2
+//       CHECK:   %[[C7:.+]] = arith.constant 7 : i32
+//       CHECK:   %[[RES:.+]] = arith.minsi %arg0, %[[C7]] : i32
+//       CHECK:   return %[[RES]]
+func.func @minsiMinsiConst2(%arg0: i32) -> i32 {
+  %c7 = arith.constant 7 : i32
+  %c21 = arith.constant 21 : i32
+  %min1 = arith.minsi %arg0, %c7 : i32
+  %min2 = arith.minsi %c21, %min1 : i32
+  return %min2 : i32
+}
+
 // CHECK-LABEL: test_minsi
 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
 // CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128
@@ -2054,6 +2126,30 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
 
 // -----
 
+// CHECK-LABEL: @minuiMinuiConst1
+//       CHECK:   %[[C17:.+]] = arith.constant 17 : index
+//       CHECK:   %[[RES:.+]] = arith.minui %arg0, %[[C17]] : index
+//       CHECK:   return %[[RES]]
+func.func @minuiMinuiConst1(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %min1 = arith.minui %arg0, %c17 : index
+  %min2 = arith.minui %min1, %c42 : index
+  return %min2 : index
+}
+
+// CHECK-LABEL: @minuiMinuiConst2
+//       CHECK:   %[[C7:.+]] = arith.constant 7 : index
+//       CHECK:   %[[RES:.+]] = arith.minui %arg0, %[[C7]] : index
+//       CHECK:   return %[[RES]]
+func.func @minuiMinuiConst2(%arg0: index) -> index {
+  %c7 = arith.constant 7 : index
+  %c21 = arith.constant 21 : index
+  %min1 = arith.minui %arg0, %c7 : index
+  %min2 = arith.minui %c21, %min1 : index
+  return %min2 : index
+}
+
 // CHECK-LABEL: test_minui
 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
 // CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0
@@ -3377,4 +3473,3 @@ func.func @unreachable() {
   %add = arith.addi %add, %c1_i64 : i64
   cf.br ^unreachable
 }
-
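
The constants expected by the added tests, and the reason the diff needs separate signed and unsigned helpers, can be sketched in plain C++. This is only an illustration: `foldSMax`, `foldUMax`, `foldSMin`, `foldUMin`, and `checkFoldedConstants` are hypothetical stand-ins operating on 8-bit values, whereas the PR combines `APInt` values through `llvm::APIntOps::smax`, `umax`, `smin`, and `umin`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for the constant-combining step of each pattern.
std::int8_t foldSMax(std::int8_t c0, std::int8_t c1) { return std::max(c0, c1); }
std::uint8_t foldUMax(std::uint8_t c0, std::uint8_t c1) { return std::max(c0, c1); }
std::int8_t foldSMin(std::int8_t c0, std::int8_t c1) { return std::min(c0, c1); }
std::uint8_t foldUMin(std::uint8_t c0, std::uint8_t c1) { return std::min(c0, c1); }

void checkFoldedConstants() {
  // Constant pairs used in the tests above: (17, 42) and (7, 21).
  assert(foldSMax(17, 42) == 42);
  assert(foldUMax(7, 21) == 21);
  assert(foldSMin(17, 42) == 17);
  assert(foldUMin(7, 21) == 7);

  // The bit pattern 0xFF is -1 when read as signed and 255 when read as
  // unsigned, so the signed and unsigned folds pick different constants.
  assert(foldSMax(-1, 1) == 1);
  assert(foldUMax(0xFF, 1) == 0xFF);
  assert(foldSMin(-1, 1) == -1);
  assert(foldUMin(0xFF, 1) == 1);
}
```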

@joker-eph
Collaborator

Looks to me like these could be implemented as a folder instead?

@joker-eph
Collaborator

Also, isn't this something that we should generically do in the folder for the IsCommutative trait?

@ziliangzl
Author

> Looks to me like these could be implemented as a folder instead?

I just followed the same approach as AddIOp. Maybe implementing this as a TableGen pattern could simplify the code?

@joker-eph
Collaborator

> I just followed the same approach as AddIOp

These should likely be moved to a folder then.

> Maybe implementing this as a TableGen pattern could simplify the code?

Can you use a DRR pattern to implement folders?

More importantly here:

> Also, isn't this something that we should generically do in the folder for the IsCommutative trait?
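
One consideration for a generic version: combining the two constants reassociates the expression, which is valid for min and max because they are associative as well as commutative. The sketch below is plain C++, not MLIR code, and `absDiff` and `checkReassociation` are hypothetical names chosen for illustration. It shows an operation that is commutative but not associative, for which the same rewrite would change the result.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>

// Hypothetical commutative but non-associative operation: |a - b|.
int absDiff(int a, int b) { return std::abs(a - b); }

void checkReassociation() {
  const int x = 1, c0 = 2, c1 = 3;

  // max is associative, so grouping the constants keeps the result.
  assert(std::max(std::max(x, c0), c1) == std::max(x, std::max(c0, c1)));

  // absDiff is commutative ...
  assert(absDiff(x, c0) == absDiff(c0, x));

  // ... but regrouping changes the value:
  //   absDiff(absDiff(1, 2), 3) = |1 - 3| = 2
  //   absDiff(1, absDiff(2, 3)) = |1 - 1| = 0
  assert(absDiff(absDiff(x, c0), c1) == 2);
  assert(absDiff(x, absDiff(c0, c1)) == 0);
}
```

Under this reading, a trait-level implementation would presumably need to know that the operation is associative, not only commutative.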
