@ziliangzl ziliangzl commented Sep 23, 2025

Adds folding for arith.maxsi, arith.maxui, arith.minsi, arith.minui, arith.maxnumf, arith.minnumf, arith.maximumf, and arith.minimumf.

  1. Fold redundant consecutive min/max operations:
    max(max(a, b), b) -> max(a, b)
    max(max(a, b), a) -> max(a, b)
    max(a, max(a, b)) -> max(a, b)
    max(b, max(a, b)) -> max(a, b)
    (similar cases for min)

  2. Fold using the absorption law:
    max(min(a, b), a) -> a
    max(min(b, a), a) -> a
    max(a, min(a, b)) -> a
    max(a, min(b, a)) -> a
    (similar cases for min; not applicable to floating-point)
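The two integer fold families above can be sanity-checked by brute force. The following is a minimal sketch, not part of the PR: it assumes that `std::max`/`std::min` on `int8_t` and `uint8_t` stand in for the signed and unsigned arith ops, and the helper name `foldsHoldForAllPairs` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// Exhaustively checks both fold families for every pair of 8-bit values.
// T = std::int8_t models maxsi/minsi; T = std::uint8_t models maxui/minui.
// (Hypothetical helper for illustration; not an MLIR API.)
template <typename T>
bool foldsHoldForAllPairs() {
  const int lo = std::numeric_limits<T>::min();
  const int hi = std::numeric_limits<T>::max();
  for (int i = lo; i <= hi; ++i) {
    for (int j = lo; j <= hi; ++j) {
      const T a = static_cast<T>(i);
      const T b = static_cast<T>(j);
      const T mx = std::max(a, b);
      const T mn = std::min(a, b);
      // Redundancy: re-applying max (or min) with an operand that was
      // already consumed does not change the result.
      if (std::max(mx, b) != mx || std::max(mx, a) != mx) return false;
      if (std::max(a, mx) != mx || std::max(b, mx) != mx) return false;
      if (std::min(mn, b) != mn || std::min(mn, a) != mn) return false;
      if (std::min(a, mn) != mn || std::min(b, mn) != mn) return false;
      // Absorption: max(min(a, b), a) == a and min(max(a, b), a) == a.
      if (std::max(mn, a) != a || std::max(a, mn) != a) return false;
      if (std::min(mx, a) != a || std::min(a, mx) != a) return false;
    }
  }
  return true;
}
```

Calling `foldsHoldForAllPairs<std::int8_t>()` and `foldsHoldForAllPairs<std::uint8_t>()` covers the signed and unsigned cases; wider integer types follow the same ordering argument.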

…secutive ops

Supported folding for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.

1. Fold redundant consecutive min/max operations:
  max(max(a, b), b) -> max(a, b)
  max(max(a, b), a) -> max(a, b)
  max(a, max(a, b)) -> max(a, b)
  max(b, max(a, b)) -> max(a, b)
  (similar cases for min)

2. Fold using the absorption law:
  max(min(a, b), a) -> a
  max(min(b, a), a) -> a
  max(a, min(a, b)) -> a
  max(a, min(b, a)) -> a
  (similar cases for min)

@llvmbot
Member

llvmbot commented Sep 23, 2025

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Ziliang Zhang (ziliangzl)

Changes

…secutive ops

Supported folding for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.

  1. Fold redundant consecutive min/max operations:
    max(max(a, b), b) -> max(a, b)
    max(max(a, b), a) -> max(a, b)
    max(a, max(a, b)) -> max(a, b)
    max(b, max(a, b)) -> max(a, b)
    (similar cases for min)

  2. Fold using the absorption law:
    max(min(a, b), a) -> a
    max(min(b, a), a) -> a
    max(a, min(a, b)) -> a
    max(a, min(b, a)) -> a
    (similar cases for min)


Full diff: https://github.com/llvm/llvm-project/pull/160224.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+96)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+136-1)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3a98df8..ea95d15b96f0c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1156,6 +1156,30 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
       return getLhs();
   }
 
+  // max(max(a, b), b) -> max(a, b)
+  // max(max(a, b), a) -> max(a, b)
+  if (auto max = getLhs().getDefiningOp<MaxSIOp>())
+    if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
+      return getLhs();
+
+  // max(a, max(a, b)) -> max(a, b)
+  // max(b, max(a, b)) -> max(a, b)
+  if (auto max = getRhs().getDefiningOp<MaxSIOp>())
+    if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
+      return getRhs();
+
+  // max(min(a, b), a) -> a
+  // max(min(b, a), a) -> a
+  if (auto min = getLhs().getDefiningOp<MinSIOp>())
+    if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
+      return getRhs();
+
+  // max(a, min(a, b)) -> a
+  // max(a, min(b, a)) -> a
+  if (auto min = getRhs().getDefiningOp<MinSIOp>())
+    if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
+      return getLhs();
+
   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::smax(a, b);
@@ -1181,6 +1205,30 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
       return getLhs();
   }
 
+  // max(max(a, b), b) -> max(a, b)
+  // max(max(a, b), a) -> max(a, b)
+  if (auto max = getLhs().getDefiningOp<MaxUIOp>())
+    if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
+      return getLhs();
+
+  // max(a, max(a, b)) -> max(a, b)
+  // max(b, max(a, b)) -> max(a, b)
+  if (auto max = getRhs().getDefiningOp<MaxUIOp>())
+    if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
+      return getRhs();
+
+  // max(min(a, b), a) -> a
+  // max(min(b, a), a) -> a
+  if (auto min = getLhs().getDefiningOp<MinUIOp>())
+    if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
+      return getRhs();
+
+  // max(a, min(a, b)) -> a
+  // max(a, min(b, a)) -> a
+  if (auto min = getRhs().getDefiningOp<MinUIOp>())
+    if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
+      return getLhs();
+
   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::umax(a, b);
@@ -1242,6 +1290,30 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
       return getLhs();
   }
 
+  // min(min(a, b), b) -> min(a, b)
+  // min(min(a, b), a) -> min(a, b)
+  if (auto min = getLhs().getDefiningOp<MinSIOp>())
+    if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
+      return getLhs();
+
+  // min(a, min(a, b)) -> min(a, b)
+  // min(b, min(a, b)) -> min(a, b)
+  if (auto min = getRhs().getDefiningOp<MinSIOp>())
+    if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
+      return getRhs();
+
+  // min(max(a, b), a) -> a
+  // min(max(b, a), a) -> a
+  if (auto max = getLhs().getDefiningOp<MaxSIOp>())
+    if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
+      return getRhs();
+
+  // min(a, max(a, b)) -> a
+  // min(a, max(b, a)) -> a
+  if (auto max = getRhs().getDefiningOp<MaxSIOp>())
+    if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
+      return getLhs();
+
   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::smin(a, b);
@@ -1267,6 +1339,30 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
       return getLhs();
   }
 
+  // min(min(a, b), b) -> min(a, b)
+  // min(min(a, b), a) -> min(a, b)
+  if (auto min = getLhs().getDefiningOp<MinUIOp>())
+    if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
+      return getLhs();
+
+  // min(a, min(a, b)) -> min(a, b)
+  // min(b, min(a, b)) -> min(a, b)
+  if (auto min = getRhs().getDefiningOp<MinUIOp>())
+    if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
+      return getRhs();
+
+  // min(max(a, b), a) -> a
+  // min(max(b, a), a) -> a
+  if (auto max = getLhs().getDefiningOp<MaxUIOp>())
+    if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
+      return getRhs();
+
+  // min(a, max(a, b)) -> a
+  // min(a, max(b, a)) -> a
+  if (auto max = getRhs().getDefiningOp<MaxUIOp>())
+    if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
+      return getLhs();
+
   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::umin(a, b);
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ca3de3a2d7703..afa53a33e79fe 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1984,6 +1984,40 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
   return %0, %1, %2, %3: i8, i8, i8, i8
 }
 
+// CHECK-LABEL:   foldMaxsiMaxsi1
+// CHECK:           %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32
+// CHECK:           return %[[MAXSI]] : i32
+func.func public @foldMaxsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 {
+  %max1 = arith.maxsi %arg1, %arg0 : i32
+  %max2 = arith.maxsi %max1, %arg1 : i32
+  func.return %max2 : i32
+}
+
+// CHECK-LABEL:   foldMaxsiMaxsi2
+// CHECK:           %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32
+// CHECK:           return %[[MAXSI]] : i32
+func.func public @foldMaxsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 {
+  %max1 = arith.maxsi %arg1, %arg0 : i32
+  %max2 = arith.maxsi %arg1, %max1 : i32
+  func.return %max2 : i32
+}
+
+// CHECK-LABEL:   foldMaxsiMinsi1
+// CHECK:           return %arg0 : i32
+func.func public @foldMaxsiMinsi1(%arg0: i32, %arg1: i32) -> i32 {
+  %min1 = arith.minsi %arg1, %arg0 : i32
+  %max2 = arith.maxsi %min1, %arg0 : i32
+  func.return %max2 : i32
+}
+
+// CHECK-LABEL:   foldMaxsiMinsi2
+// CHECK:           return %arg0 : i32
+func.func public @foldMaxsiMinsi2(%arg0: i32, %arg1: i32) -> i32 {
+  %min1 = arith.minsi %arg1, %arg0 : i32
+  %max2 = arith.maxsi %arg0, %min1 : i32
+  func.return %max2 : i32
+}
+
 // -----
 
 // CHECK-LABEL: test_maxui
@@ -2018,6 +2052,40 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
   return %0, %1, %2, %3: i8, i8, i8, i8
 }
 
+// CHECK-LABEL:   foldMaxuiMaxui1
+// CHECK:           %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32
+// CHECK:           return %[[MAXUI]] : i32
+func.func public @foldMaxuiMaxui1(%arg0: i32, %arg1: i32) -> i32 {
+  %max1 = arith.maxui %arg1, %arg0 : i32
+  %max2 = arith.maxui %max1, %arg1 : i32
+  func.return %max2 : i32
+}
+
+// CHECK-LABEL:   foldMaxuiMaxui2
+// CHECK:           %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32
+// CHECK:           return %[[MAXUI]] : i32
+func.func public @foldMaxuiMaxui2(%arg0: i32, %arg1: i32) -> i32 {
+  %max1 = arith.maxui %arg1, %arg0 : i32
+  %max2 = arith.maxui %arg1, %max1 : i32
+  func.return %max2 : i32
+}
+
+// CHECK-LABEL:   foldMaxuiMinui1
+// CHECK:           return %arg0 : i32
+func.func public @foldMaxuiMinui1(%arg0: i32, %arg1: i32) -> i32 {
+  %min1 = arith.minui %arg1, %arg0 : i32
+  %max2 = arith.maxui %min1, %arg0 : i32
+  func.return %max2 : i32
+}
+
+// CHECK-LABEL:   foldMaxuiMinui2
+// CHECK:           return %arg0 : i32
+func.func public @foldMaxuiMinui2(%arg0: i32, %arg1: i32) -> i32 {
+  %min1 = arith.minui %arg1, %arg0 : i32
+  %max2 = arith.maxui %arg0, %min1 : i32
+  func.return %max2 : i32
+}
+
 // -----
 
 // CHECK-LABEL: test_minsi
@@ -2052,6 +2120,40 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
   return %0, %1, %2, %3: i8, i8, i8, i8
 }
 
+// CHECK-LABEL:   foldMinsiMinsi1
+// CHECK:           %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32
+// CHECK:           return %[[MINSI]] : i32
+func.func public @foldMinsiMinsi1(%arg0: i32, %arg1: i32) -> i32 {
+  %min1 = arith.minsi %arg1, %arg0 : i32
+  %min2 = arith.minsi %min1, %arg1 : i32
+  func.return %min2 : i32
+}
+
+// CHECK-LABEL:   foldMinsiMinsi2
+// CHECK:           %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32
+// CHECK:           return %[[MINSI]] : i32
+func.func public @foldMinsiMinsi2(%arg0: i32, %arg1: i32) -> i32 {
+  %min1 = arith.minsi %arg1, %arg0 : i32
+  %min2 = arith.minsi %arg1, %min1 : i32
+  func.return %min2 : i32
+}
+
+// CHECK-LABEL:   foldMinsiMaxsi1
+// CHECK:           return %arg0 : i32
+func.func public @foldMinsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 {
+  %min1 = arith.maxsi %arg1, %arg0 : i32
+  %min2 = arith.minsi %min1, %arg0 : i32
+  func.return %min2 : i32
+}
+
+// CHECK-LABEL:   foldMinsiMaxsi2
+// CHECK:           return %arg0 : i32
+func.func public @foldMinsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 {
+  %min1 = arith.maxsi %arg1, %arg0 : i32
+  %min2 = arith.minsi %arg0, %min1 : i32
+  func.return %min2 : i32
+}
+
 // -----
 
 // CHECK-LABEL: test_minui
@@ -2086,6 +2188,40 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {
   return %0, %1, %2, %3: i8, i8, i8, i8
 }
 
+// CHECK-LABEL:   foldMinuiMinui1
+// CHECK:           %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32
+// CHECK:           return %[[MINUI]] : i32
+func.func public @foldMinuiMinui1(%arg0: i32, %arg1: i32) -> i32 {
+  %min1 = arith.minui %arg1, %arg0 : i32
+  %min2 = arith.minui %min1, %arg1 : i32
+  func.return %min2 : i32
+}
+
+// CHECK-LABEL:   foldMinuiMinui2
+// CHECK:           %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32
+// CHECK:           return %[[MINUI]] : i32
+func.func public @foldMinuiMinui2(%arg0: i32, %arg1: i32) -> i32 {
+  %min1 = arith.minui %arg1, %arg0 : i32
+  %min2 = arith.minui %arg1, %min1 : i32
+  func.return %min2 : i32
+}
+
+// CHECK-LABEL:   foldMinuiMaxui1
+// CHECK:           return %arg0 : i32
+func.func public @foldMinuiMaxui1(%arg0: i32, %arg1: i32) -> i32 {
+  %max1 = arith.maxui %arg1, %arg0 : i32
+  %min2 = arith.minui %max1, %arg0 : i32
+  func.return %min2 : i32
+}
+
+// CHECK-LABEL:   foldMinuiMaxui2
+// CHECK:           return %arg0 : i32
+func.func public @foldMinuiMaxui2(%arg0: i32, %arg1: i32) -> i32 {
+  %max1 = arith.maxui %arg1, %arg0 : i32
+  %min2 = arith.minui %arg0, %max1 : i32
+  func.return %min2 : i32
+}
+
 // -----
 
 // CHECK-LABEL: @test_minimumf(
@@ -3377,4 +3513,3 @@ func.func @unreachable() {
   %add = arith.addi %add, %c1_i64 : i64
   cf.br ^unreachable
 }
-

@ziliangzl ziliangzl changed the title from "[mlir][arith] Fold min/max ops using absorption law and redundant con…" to "[mlir][arith] Fold min/max with absorption and redundancy" Sep 23, 2025
@kuhar kuhar self-requested a review September 23, 2025 12:33
Comment on lines +1159 to +1166
  // max(max(a, b), b) -> max(a, b)
  // max(max(a, b), a) -> max(a, b)
  if (auto max = getLhs().getDefiningOp<MaxSIOp>())
    if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
      return getLhs();

  // max(a, max(a, b)) -> max(a, b)
  // max(b, max(a, b)) -> max(a, b)
Member

These look fine to me

Comment on lines +1171 to +1178
  // max(min(a, b), a) -> a
  // max(min(b, a), a) -> a
  if (auto min = getLhs().getDefiningOp<MinSIOp>())
    if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
      return getRhs();

  // max(a, min(a, b)) -> a
  // max(a, min(b, a)) -> a
Member

I was wondering if this is correct when one of the values is poison and gets discarded, but this should be fine since the optimization makes the result more defined: https://alive2.llvm.org/ce/z/cdoy8x
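The "more defined" argument can be illustrated with a toy model. This is only a sketch under stated assumptions: poison is modeled as `std::nullopt`, any poison operand poisons the result, and the names `Val`, `maxsiModel`, `minsiModel`, `refines`, and `absorptionRefinesUnderPoison` are hypothetical. It does not replace the Alive2 proof linked above.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical model: std::nullopt stands for a poison value.
using Val = std::optional<int>;

// In this model, any poison operand makes the result poison.
Val maxsiModel(Val x, Val y) {
  if (!x || !y) return std::nullopt;
  return std::max(*x, *y);
}

Val minsiModel(Val x, Val y) {
  if (!x || !y) return std::nullopt;
  return std::min(*x, *y);
}

// `after` refines `before` if it agrees with `before` wherever `before`
// is defined; a poison `before` may be replaced by anything.
bool refines(Val before, Val after) {
  return !before.has_value() || before == after;
}

// Checks max(min(a, b), a) -> a over a small sample that includes poison.
bool absorptionRefinesUnderPoison() {
  std::vector<Val> samples = {std::nullopt};
  for (int v = -3; v <= 3; ++v) samples.push_back(v);
  for (const Val &a : samples)
    for (const Val &b : samples)
      if (!refines(maxsiModel(minsiModel(a, b), a), a)) return false;
  return true;
}
```

The interesting case is a defined `a` with a poison `b`: the unfolded expression is poison, while the folded result is `a`, which is allowed because the fold only removes poison and never introduces it.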

// CHECK-LABEL: foldMaxsiMaxsi1
// CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32
// CHECK: return %[[MAXSI]] : i32
func.func public @foldMaxsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 {
Member

Can you drop public from these functions? We don't need it.

Author

Done.

Member

@kuhar kuhar left a comment

Looks good overall, just one nit (public).

Could we do something similar in the floating point case? I imagine it may be more difficult depending on the exact NaN handling.

@ziliangzl
Author

> Looks good overall, just one nit (public).
>
> Could we do something similar in the floating point case? I imagine it may be more difficult depending on the exact NaN handling.

For the floating point case: you’re right, the NaN semantics make direct folding tricky. I think a safe starting point could be to add canonicalization patterns instead, e.g. rewriting max(max(x, c1), c2) into max(x, max(c1, c2)) for integer or floating point. I can do that in a follow-up PR.
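To make the NaN concern concrete, here is a small sketch. It assumes simplified semantics: `maximumLike`/`minimumLike` propagate NaN (in the spirit of arith.maximumf/arith.minimumf) and `maxnumLike`/`minnumLike` return the non-NaN operand (in the spirit of arith.maxnumf/arith.minnumf, modeled with `std::fmax`/`std::fmin`). All helper names are hypothetical and signed zeros are ignored.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// NaN-propagating model (assumed to mirror arith.maximumf / arith.minimumf).
float maximumLike(float x, float y) {
  if (std::isnan(x) || std::isnan(y))
    return std::numeric_limits<float>::quiet_NaN();
  return x > y ? x : y;
}
float minimumLike(float x, float y) {
  if (std::isnan(x) || std::isnan(y))
    return std::numeric_limits<float>::quiet_NaN();
  return x < y ? x : y;
}

// NaN-ignoring model (assumed to mirror arith.maxnumf / arith.minnumf).
float maxnumLike(float x, float y) { return std::fmax(x, y); }
float minnumLike(float x, float y) { return std::fmin(x, y); }

// Equality that treats NaN as equal to NaN, so that folded and unfolded
// results can be compared.
bool sameValue(float x, float y) {
  return (std::isnan(x) && std::isnan(y)) || x == y;
}

// Absorption max(min(a, b), a) -> a fails under both models.
bool absorptionBreaksWithNaN() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  // Propagating model, b = NaN: the expression is NaN, but a is 1.0.
  bool propagating =
      !sameValue(maximumLike(minimumLike(1.0f, nan), 1.0f), 1.0f);
  // Ignoring model, a = NaN: the expression is 1.0, but a is NaN.
  bool ignoring = !sameValue(maxnumLike(minnumLike(nan, 1.0f), nan), nan);
  return propagating && ignoring;
}

// Redundancy max(max(a, b), x) -> max(a, b) for x in {a, b} holds on
// these samples under both models.
bool redundancyHoldsOnSamples() {
  const float samples[] = {std::numeric_limits<float>::quiet_NaN(), -1.0f,
                           0.0f, 2.5f, std::numeric_limits<float>::infinity()};
  for (float a : samples) {
    for (float b : samples) {
      const float m = maximumLike(a, b);
      const float n = maxnumLike(a, b);
      if (!sameValue(maximumLike(m, a), m) || !sameValue(maximumLike(m, b), m))
        return false;
      if (!sameValue(maxnumLike(n, a), n) || !sameValue(maxnumLike(n, b), n))
        return false;
    }
  }
  return true;
}
```

Under these assumptions only the redundancy fold survives for floating point, which is consistent with the PR description marking absorption as not applicable to floating-point.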

@ziliangzl
Author

Hi @kuhar,

I’ve added the fold pattern min(min(a, b), b) -> min(a, b) for maximumf, minimumf, maxnumf, and minnumf. I believe this folding is safe. Could you please take a look and help review/merge it?

Thanks!

@ziliangzl
Author

Hi @kuhar, I submitted another PR (#161057) that adds canonicalization patterns for nested min/max operations with constants. Could you please take a look?

// CHECK-LABEL: foldMinimumfMinimumf1
// CHECK: %[[MINF:.*]] = arith.minimumf %arg1, %arg0 : f32
// CHECK: return %[[MINF]] : f32
func.func public @foldMinimumfMinimumf1(%arg0: f32, %arg1: f32) -> f32 {
Member

We don't need this public here and elsewhere

Author

Thanks, removed.

Comment on lines +1119 to +1126
  // max(max(a, b), b) -> max(a, b)
  // max(max(a, b), a) -> max(a, b)
  if (auto max = getLhs().getDefiningOp<MaximumFOp>())
    if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
      return getLhs();

  // max(a, max(a, b)) -> max(a, b)
  // max(b, max(a, b)) -> max(a, b)
Member

I wonder if instead of covering both cases here and for other ops, we should instead use the fact that these are commutative and first move the other min/max operand to the RHS?

Author

@ziliangzl ziliangzl Sep 29, 2025

Good point. Currently, the Commutative trait only provides a generic fold that moves constant operands to the end:

LogicalResult
OpTrait::impl::foldCommutative(Operation *op, ArrayRef<Attribute> operands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // Nothing to fold if there are not at least 2 operands.
  if (op->getNumOperands() < 2)
    return failure();
  // Move all constant operands to the end.
  OpOperand *operandsBegin = op->getOpOperands().begin();
  auto isNonConstant = [&](OpOperand &o) {
    return !static_cast<bool>(operands[std::distance(operandsBegin, &o)]);
  };
  auto *firstConstantIt = llvm::find_if_not(op->getOpOperands(), isNonConstant);
  auto *newConstantIt = std::stable_partition(
      firstConstantIt, op->getOpOperands().end(), isNonConstant);
  // Return success if the op was modified.
  return success(firstConstantIt != newConstantIt);
}

So it helps canonicalization patterns such as max(max(x, c0), c1) -> max(x, max(c0, c1)), where constants are involved.

But for the fold patterns in this PR, such as max(max(a, b), a) -> max(a, b), no constants are involved, so the commutative fold doesn’t normalize the operands. That’s why I explicitly check both lhs and rhs here.

An alternative would be to normalize operands inside the fold itself (e.g., always move the nested max to the RHS), but for now I kept both conditions for clarity.
cc @joker-eph

Collaborator

> An alternative would be to normalize operands inside the fold itself (e.g., always move the nested max to the RHS)

Yes, I suggested this in the other revision: doing this in the trait would align all the commutative operations and in turn simplify all the folders (by avoiding checks for redundant forms).
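The normalization idea can be sketched on a toy expression tree. The `Node` type and `foldMax` function below are hypothetical and unrelated to MLIR's `Value`/`Operation` classes; the sketch only shows how moving the nested max to the RHS leaves a single form to match.

```cpp
#include <cassert>
#include <utility>

// Hypothetical toy IR node: either a leaf value or max(lhs, rhs).
struct Node {
  enum Kind { Leaf, Max } kind = Leaf;
  const Node *lhs = nullptr;
  const Node *rhs = nullptr;
};

// Returns the node that max(lhs, rhs) folds to, or nullptr if the
// redundancy fold does not apply.
const Node *foldMax(const Node *lhs, const Node *rhs) {
  // Normalization: max is commutative, so move a nested max to the RHS.
  if (lhs->kind == Node::Max && rhs->kind != Node::Max)
    std::swap(lhs, rhs);
  // Only one form is left to check: max(x, max(a, b)) with x == a or x == b.
  // Limitation of this sketch: when both operands are max nodes, only the
  // RHS is inspected.
  if (rhs->kind == Node::Max && (lhs == rhs->lhs || lhs == rhs->rhs))
    return rhs;
  return nullptr;
}
```

With leaves `a`, `b` and `m = max(a, b)`, both `foldMax(&m, &a)` and `foldMax(&a, &m)` return `&m` through the same single check. Doing the swap once in the commutative trait, as suggested above, would let every folder rely on this canonical operand order.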
