[mlir][arith] Fold min/max with absorption and redundancy #160224
base: main
…secutive ops
Supported folding for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.
1. Fold redundant consecutive min/max operations:
max(max(a, b), b) -> max(a, b)
max(max(a, b), a) -> max(a, b)
max(a, max(a, b)) -> max(a, b)
max(b, max(a, b)) -> max(a, b)
(similar cases for min)
2. Fold using the absorption law:
max(min(a, b), a) -> a
max(min(b, a), a) -> a
max(a, min(a, b)) -> a
max(a, min(b, a)) -> a
(similar cases for min)
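As a sanity check on the identities listed above, here is a minimal sketch that tests them exhaustively over all 8-bit values. It is not part of the PR: `identitiesHoldForAllPairs` is a hypothetical helper, and `std::max`/`std::min` on `int8_t`/`uint8_t` are assumed to stand in for the signed and unsigned arith ops.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper (not from the PR): returns true if the redundancy and
// absorption identities hold for every pair (a, b) of values of type T.
template <typename T>
bool identitiesHoldForAllPairs() {
  const int lo = std::numeric_limits<T>::min();
  const int hi = std::numeric_limits<T>::max();
  for (int i = lo; i <= hi; ++i) {
    for (int j = lo; j <= hi; ++j) {
      const T a = static_cast<T>(i);
      const T b = static_cast<T>(j);
      const T mx = std::max(a, b);
      const T mn = std::min(a, b);
      // Redundancy: max(max(a, b), x) == max(a, b) for x in {a, b}.
      if (std::max(mx, a) != mx || std::max(mx, b) != mx)
        return false;
      if (std::max(a, mx) != mx || std::max(b, mx) != mx)
        return false;
      // Redundancy, min variant.
      if (std::min(mn, a) != mn || std::min(mn, b) != mn)
        return false;
      if (std::min(a, mn) != mn || std::min(b, mn) != mn)
        return false;
      // Absorption: max(min(a, b), a) == a and min(max(a, b), a) == a.
      if (std::max(mn, a) != a || std::max(a, mn) != a)
        return false;
      if (std::min(mx, a) != a || std::min(a, mx) != a)
        return false;
    }
  }
  return true;
}
```

Instantiating the helper with `std::int8_t` and `std::uint8_t` covers the signed and unsigned cases; both checks pass because the identities hold in any total order.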
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir
Author: Ziliang Zhang (ziliangzl)
Changes
…secutive ops
Supported folding for arith.maxsi, arith.maxui, arith.minsi, and arith.minui.
Full diff: https://github.com/llvm/llvm-project/pull/160224.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3a98df8..ea95d15b96f0c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1156,6 +1156,30 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
       return getLhs();
   }

+  // max(max(a, b), b) -> max(a, b)
+  // max(max(a, b), a) -> max(a, b)
+  if (auto max = getLhs().getDefiningOp<MaxSIOp>())
+    if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
+      return getLhs();
+
+  // max(a, max(a, b)) -> max(a, b)
+  // max(b, max(a, b)) -> max(a, b)
+  if (auto max = getRhs().getDefiningOp<MaxSIOp>())
+    if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
+      return getRhs();
+
+  // max(min(a, b), a) -> a
+  // max(min(b, a), a) -> a
+  if (auto min = getLhs().getDefiningOp<MinSIOp>())
+    if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
+      return getRhs();
+
+  // max(a, min(a, b)) -> a
+  // max(a, min(b, a)) -> a
+  if (auto min = getRhs().getDefiningOp<MinSIOp>())
+    if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
+      return getLhs();
+
   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::smax(a, b);
@@ -1181,6 +1205,30 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
       return getLhs();
   }

+  // max(max(a, b), b) -> max(a, b)
+  // max(max(a, b), a) -> max(a, b)
+  if (auto max = getLhs().getDefiningOp<MaxUIOp>())
+    if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
+      return getLhs();
+
+  // max(a, max(a, b)) -> max(a, b)
+  // max(b, max(a, b)) -> max(a, b)
+  if (auto max = getRhs().getDefiningOp<MaxUIOp>())
+    if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
+      return getRhs();
+
+  // max(min(a, b), a) -> a
+  // max(min(b, a), a) -> a
+  if (auto min = getLhs().getDefiningOp<MinUIOp>())
+    if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
+      return getRhs();
+
+  // max(a, min(a, b)) -> a
+  // max(a, min(b, a)) -> a
+  if (auto min = getRhs().getDefiningOp<MinUIOp>())
+    if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
+      return getLhs();
+
   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::umax(a, b);
@@ -1242,6 +1290,30 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
       return getLhs();
   }

+  // min(min(a, b), b) -> min(a, b)
+  // min(min(a, b), a) -> min(a, b)
+  if (auto min = getLhs().getDefiningOp<MinSIOp>())
+    if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
+      return getLhs();
+
+  // min(a, min(a, b)) -> min(a, b)
+  // min(b, min(a, b)) -> min(a, b)
+  if (auto min = getRhs().getDefiningOp<MinSIOp>())
+    if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
+      return getRhs();
+
+  // min(max(a, b), a) -> a
+  // min(max(b, a), a) -> a
+  if (auto max = getLhs().getDefiningOp<MaxSIOp>())
+    if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
+      return getRhs();
+
+  // min(a, max(a, b)) -> a
+  // min(a, max(b, a)) -> a
+  if (auto max = getRhs().getDefiningOp<MaxSIOp>())
+    if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
+      return getLhs();
+
   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::smin(a, b);
@@ -1267,6 +1339,30 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
       return getLhs();
   }

+  // min(min(a, b), b) -> min(a, b)
+  // min(min(a, b), a) -> min(a, b)
+  if (auto min = getLhs().getDefiningOp<MinUIOp>())
+    if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
+      return getLhs();
+
+  // min(a, min(a, b)) -> min(a, b)
+  // min(b, min(a, b)) -> min(a, b)
+  if (auto min = getRhs().getDefiningOp<MinUIOp>())
+    if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
+      return getRhs();
+
+  // min(max(a, b), a) -> a
+  // min(max(b, a), a) -> a
+  if (auto max = getLhs().getDefiningOp<MaxUIOp>())
+    if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
+      return getRhs();
+
+  // min(a, max(a, b)) -> a
+  // min(a, max(b, a)) -> a
+  if (auto max = getRhs().getDefiningOp<MaxUIOp>())
+    if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
+      return getLhs();
+
   return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                         [](const APInt &a, const APInt &b) {
                                           return llvm::APIntOps::umin(a, b);
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ca3de3a2d7703..afa53a33e79fe 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1984,6 +1984,40 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
return %0, %1, %2, %3: i8, i8, i8, i8
}
+// CHECK-LABEL: foldMaxsiMaxsi1
+// CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32
+// CHECK: return %[[MAXSI]] : i32
+func.func public @foldMaxsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 {
+ %max1 = arith.maxsi %arg1, %arg0 : i32
+ %max2 = arith.maxsi %max1, %arg1 : i32
+ func.return %max2 : i32
+}
+
+// CHECK-LABEL: foldMaxsiMaxsi2
+// CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32
+// CHECK: return %[[MAXSI]] : i32
+func.func public @foldMaxsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 {
+ %max1 = arith.maxsi %arg1, %arg0 : i32
+ %max2 = arith.maxsi %arg1, %max1 : i32
+ func.return %max2 : i32
+}
+
+// CHECK-LABEL: foldMaxsiMinsi1
+// CHECK: return %arg0 : i32
+func.func public @foldMaxsiMinsi1(%arg0: i32, %arg1: i32) -> i32 {
+ %min1 = arith.minsi %arg1, %arg0 : i32
+ %max2 = arith.maxsi %min1, %arg0 : i32
+ func.return %max2 : i32
+}
+
+// CHECK-LABEL: foldMaxsiMinsi2
+// CHECK: return %arg0 : i32
+func.func public @foldMaxsiMinsi2(%arg0: i32, %arg1: i32) -> i32 {
+ %min1 = arith.minsi %arg1, %arg0 : i32
+ %max2 = arith.maxsi %arg0, %min1 : i32
+ func.return %max2 : i32
+}
+
// -----
// CHECK-LABEL: test_maxui
@@ -2018,6 +2052,40 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
return %0, %1, %2, %3: i8, i8, i8, i8
}
+// CHECK-LABEL: foldMaxuiMaxui1
+// CHECK: %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32
+// CHECK: return %[[MAXUI]] : i32
+func.func public @foldMaxuiMaxui1(%arg0: i32, %arg1: i32) -> i32 {
+ %max1 = arith.maxui %arg1, %arg0 : i32
+ %max2 = arith.maxui %max1, %arg1 : i32
+ func.return %max2 : i32
+}
+
+// CHECK-LABEL: foldMaxuiMaxui2
+// CHECK: %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32
+// CHECK: return %[[MAXUI]] : i32
+func.func public @foldMaxuiMaxui2(%arg0: i32, %arg1: i32) -> i32 {
+ %max1 = arith.maxui %arg1, %arg0 : i32
+ %max2 = arith.maxui %arg1, %max1 : i32
+ func.return %max2 : i32
+}
+
+// CHECK-LABEL: foldMaxuiMinui1
+// CHECK: return %arg0 : i32
+func.func public @foldMaxuiMinui1(%arg0: i32, %arg1: i32) -> i32 {
+ %min1 = arith.minui %arg1, %arg0 : i32
+ %max2 = arith.maxui %min1, %arg0 : i32
+ func.return %max2 : i32
+}
+
+// CHECK-LABEL: foldMaxuiMinui2
+// CHECK: return %arg0 : i32
+func.func public @foldMaxuiMinui2(%arg0: i32, %arg1: i32) -> i32 {
+ %min1 = arith.minui %arg1, %arg0 : i32
+ %max2 = arith.maxui %arg0, %min1 : i32
+ func.return %max2 : i32
+}
+
// -----
// CHECK-LABEL: test_minsi
@@ -2052,6 +2120,40 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
return %0, %1, %2, %3: i8, i8, i8, i8
}
+// CHECK-LABEL: foldMinsiMinsi1
+// CHECK: %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32
+// CHECK: return %[[MINSI]] : i32
+func.func public @foldMinsiMinsi1(%arg0: i32, %arg1: i32) -> i32 {
+ %min1 = arith.minsi %arg1, %arg0 : i32
+ %min2 = arith.minsi %min1, %arg1 : i32
+ func.return %min2 : i32
+}
+
+// CHECK-LABEL: foldMinsiMinsi2
+// CHECK: %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32
+// CHECK: return %[[MINSI]] : i32
+func.func public @foldMinsiMinsi2(%arg0: i32, %arg1: i32) -> i32 {
+ %min1 = arith.minsi %arg1, %arg0 : i32
+ %min2 = arith.minsi %arg1, %min1 : i32
+ func.return %min2 : i32
+}
+
+// CHECK-LABEL: foldMinsiMaxsi1
+// CHECK: return %arg0 : i32
+func.func public @foldMinsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 {
+ %min1 = arith.maxsi %arg1, %arg0 : i32
+ %min2 = arith.minsi %min1, %arg0 : i32
+ func.return %min2 : i32
+}
+
+// CHECK-LABEL: foldMinsiMaxsi2
+// CHECK: return %arg0 : i32
+func.func public @foldMinsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 {
+ %min1 = arith.maxsi %arg1, %arg0 : i32
+ %min2 = arith.minsi %arg0, %min1 : i32
+ func.return %min2 : i32
+}
+
// -----
// CHECK-LABEL: test_minui
@@ -2086,6 +2188,40 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {
return %0, %1, %2, %3: i8, i8, i8, i8
}
+// CHECK-LABEL: foldMinuiMinui1
+// CHECK: %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32
+// CHECK: return %[[MINUI]] : i32
+func.func public @foldMinuiMinui1(%arg0: i32, %arg1: i32) -> i32 {
+ %min1 = arith.minui %arg1, %arg0 : i32
+ %min2 = arith.minui %min1, %arg1 : i32
+ func.return %min2 : i32
+}
+
+// CHECK-LABEL: foldMinuiMinui2
+// CHECK: %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32
+// CHECK: return %[[MINUI]] : i32
+func.func public @foldMinuiMinui2(%arg0: i32, %arg1: i32) -> i32 {
+ %min1 = arith.minui %arg1, %arg0 : i32
+ %min2 = arith.minui %arg1, %min1 : i32
+ func.return %min2 : i32
+}
+
+// CHECK-LABEL: foldMinuiMaxui1
+// CHECK: return %arg0 : i32
+func.func public @foldMinuiMaxui1(%arg0: i32, %arg1: i32) -> i32 {
+ %max1 = arith.maxui %arg1, %arg0 : i32
+ %min2 = arith.minui %max1, %arg0 : i32
+ func.return %min2 : i32
+}
+
+// CHECK-LABEL: foldMinuiMaxui2
+// CHECK: return %arg0 : i32
+func.func public @foldMinuiMaxui2(%arg0: i32, %arg1: i32) -> i32 {
+ %max1 = arith.maxui %arg1, %arg0 : i32
+ %min2 = arith.minui %arg0, %max1 : i32
+ func.return %min2 : i32
+}
+
// -----
// CHECK-LABEL: @test_minimumf(
@@ -3377,4 +3513,3 @@ func.func @unreachable() {
%add = arith.addi %add, %c1_i64 : i64
cf.br ^unreachable
}
-
// max(max(a, b), b) -> max(a, b)
// max(max(a, b), a) -> max(a, b)
if (auto max = getLhs().getDefiningOp<MaxSIOp>())
  if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
    return getLhs();

// max(a, max(a, b)) -> max(a, b)
// max(b, max(a, b)) -> max(a, b)
These look fine to me
// max(min(a, b), a) -> a
// max(min(b, a), a) -> a
if (auto min = getLhs().getDefiningOp<MinSIOp>())
  if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
    return getRhs();

// max(a, min(a, b)) -> a
// max(a, min(b, a)) -> a
I was wondering if this is correct when one of the values is poison and gets discarded, but this should be fine since the optimization makes the result more defined: https://alive2.llvm.org/ce/z/cdoy8x
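To illustrate the "more defined" argument, here is a toy model rather than the Alive2 proof linked above. It assumes poison can be modeled as `std::nullopt` and that it propagates through min/max whenever either operand is poison. All names (`Val`, `smaxP`, `sminP`, `original`, `folded`, `refines`) are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>

// Toy model (assumption): std::nullopt stands for poison.
using Val = std::optional<int>;

// Poison-propagating signed max/min: poison in, poison out.
Val smaxP(Val x, Val y) {
  if (!x || !y)
    return std::nullopt;
  return std::max(*x, *y);
}

Val sminP(Val x, Val y) {
  if (!x || !y)
    return std::nullopt;
  return std::min(*x, *y);
}

// Expression before the fold: max(min(a, b), a).
Val original(Val a, Val b) { return smaxP(sminP(a, b), a); }

// Expression after the fold: a (b is discarded).
Val folded(Val a, Val /*b*/) { return a; }

// tgt refines src if, whenever src is defined, tgt is the same defined value.
// A poison src may be replaced by anything.
bool refines(Val src, Val tgt) { return !src || (tgt && *tgt == *src); }
```

When `b` is poison and `a` is 5, `original` yields poison while `folded` yields 5, so the fold only replaces poison with a defined value. When both inputs are defined, the two expressions agree.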
// CHECK-LABEL: foldMaxsiMaxsi1
// CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32
// CHECK: return %[[MAXSI]] : i32
func.func public @foldMaxsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 {
Can you drop public from these functions? We don't need it.
Done.
Looks good overall, just one nit (public).
Could we do something similar in the floating point case? I imagine it may be more difficult depending on the exact NaN handling.
3167e32 to 69df905
For the floating point case: you’re right, the NaN semantics make direct folding tricky. I think a safe starting point could be to add canonicalization patterns instead, e.g. rewriting
Hi @kuhar, I’ve added the fold pattern Thanks!
// CHECK-LABEL: foldMinimumfMinimumf1
// CHECK: %[[MINF:.*]] = arith.minimumf %arg1, %arg0 : f32
// CHECK: return %[[MINF]] : f32
func.func public @foldMinimumfMinimumf1(%arg0: f32, %arg1: f32) -> f32 {
We don't need this public here and elsewhere
Thanks, removed.
// max(max(a, b), b) -> max(a, b)
// max(max(a, b), a) -> max(a, b)
if (auto max = getLhs().getDefiningOp<MaximumFOp>())
  if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
    return getLhs();

// max(a, max(a, b)) -> max(a, b)
// max(b, max(a, b)) -> max(a, b)
I wonder if instead of covering both cases here and for other ops, we should instead use the fact that these are commutative and first move the other min/max operand to the RHS?
Good point. Currently, the Commutative trait only provides a generic fold that moves constant operands to the end:
llvm-project/mlir/lib/IR/Operation.cpp
Lines 856 to 873 in cac0635
LogicalResult
OpTrait::impl::foldCommutative(Operation *op, ArrayRef<Attribute> operands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // Nothing to fold if there are not at least 2 operands.
  if (op->getNumOperands() < 2)
    return failure();
  // Move all constant operands to the end.
  OpOperand *operandsBegin = op->getOpOperands().begin();
  auto isNonConstant = [&](OpOperand &o) {
    return !static_cast<bool>(operands[std::distance(operandsBegin, &o)]);
  };
  auto *firstConstantIt = llvm::find_if_not(op->getOpOperands(), isNonConstant);
  auto *newConstantIt = std::stable_partition(
      firstConstantIt, op->getOpOperands().end(), isNonConstant);
  // Return success if the op was modified.
  return success(firstConstantIt != newConstantIt);
}
so it helps canonicalization patterns such as max(max(x, c0), c1) -> max(x, max(c0, c1)), where constants are involved.
But the fold patterns in this PR, such as max(max(a, b), a) -> max(a, b), involve no constants, so the commutative fold doesn’t normalize the operands. That’s why I explicitly check both lhs and rhs here.
An alternative would be to normalize operands inside the fold itself (e.g., always move the nested max to the RHS), but for now I kept both conditions for clarity.
cc @joker-eph
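To make the point about the commutative fold concrete, here is a standalone sketch of the same "move constants to the end" step on a plain vector, outside MLIR. `Operand` and `moveConstantsToEnd` are hypothetical names, and an operand counts as constant here simply when its `constant` field is set. As in the quoted snippet, only the range starting at the first constant is partitioned.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for an op operand: `constant` is set only when the
// operand is a known constant.
struct Operand {
  std::string name;
  std::optional<int> constant;
};

// Moves constant operands to the end while keeping relative order.
// Returns true if the operand order changed.
bool moveConstantsToEnd(std::vector<Operand> &operands) {
  // Nothing to do with fewer than 2 operands.
  if (operands.size() < 2)
    return false;
  auto isNonConstant = [](const Operand &o) { return !o.constant.has_value(); };
  // Skip the leading run of non-constant operands.
  auto firstConstantIt =
      std::find_if_not(operands.begin(), operands.end(), isNonConstant);
  // Partition the rest so that non-constants come before constants.
  auto newConstantIt =
      std::stable_partition(firstConstantIt, operands.end(), isNonConstant);
  return firstConstantIt != newConstantIt;
}
```

With operands `(c0, x)` the helper reorders them to `(x, c0)` and reports a change. With two non-constant operands such as `(max(a, b), a)` it does nothing, which matches the observation that this fold cannot normalize the cases handled in this PR.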
An alternative would be to normalize operands inside the fold itself (e.g., always move the nested max to the RHS)
Yes, I suggested this in the other revision: doing this in the trait would align all the commutative operations and, in turn, simplify all the folders (by avoiding checks for redundant forms).
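A rough sketch of what such normalization buys, using a toy expression tree instead of MLIR ops. `Kind`, `Node`, and `foldMax` are hypothetical, and the sketch only covers the case where exactly one operand is a nested min/max. Once the nested op is moved to the RHS, the folder needs a single check per rule instead of two.

```cpp
#include <cassert>
#include <utility>

// Toy IR (assumption): a node is either a leaf value or a binary max/min.
enum class Kind { Leaf, Max, Min };

struct Node {
  Kind kind = Kind::Leaf;
  const Node *lhs = nullptr;
  const Node *rhs = nullptr;
};

// Folds max(lhs, rhs). Returns the existing node the result is equal to,
// or nullptr if no fold applies.
const Node *foldMax(const Node *lhs, const Node *rhs) {
  // Normalize: max is commutative, so move a nested min/max to the RHS.
  if (lhs->kind != Kind::Leaf && rhs->kind == Kind::Leaf)
    std::swap(lhs, rhs);
  if (rhs->kind == Kind::Leaf)
    return nullptr;
  // After normalization only the "nested op on the RHS" form is checked.
  if (lhs != rhs->lhs && lhs != rhs->rhs)
    return nullptr;
  // max(a, max(a, b)) -> max(a, b); max(a, min(a, b)) -> a.
  return rhs->kind == Kind::Max ? rhs : lhs;
}
```

In this toy version the swap is local to the folder. The suggestion in the thread is to do the reordering once in the Commutative trait so that every folder can rely on it.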
Supported folding for arith.maxsi, arith.maxui, arith.minsi, arith.minui, arith.maxnumf, arith.minnumf, arith.maximumf, and arith.minimumf.
Fold redundant consecutive min/max operations:
max(max(a, b), b) -> max(a, b)
max(max(a, b), a) -> max(a, b)
max(a, max(a, b)) -> max(a, b)
max(b, max(a, b)) -> max(a, b)
(similar cases for min)
Fold using the absorption law:
max(min(a, b), a) -> a
max(min(b, a), a) -> a
max(a, min(a, b)) -> a
max(a, min(b, a)) -> a
(similar cases for min; not applicable to floating-point)
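To show why the absorption law is excluded for floating point, the sketch below builds NaN counterexamples for both flavors of min/max. It assumes `std::fmax`/`std::fmin` approximate the NaN behavior of arith.maxnumf/arith.minnumf (return the non-NaN operand), and that arith.maximumf/arith.minimumf propagate NaN. The helpers `qnan`, `maximumProp`, `minimumProp`, `absorbNum`, and `absorbProp` are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Quiet NaN, for building counterexamples.
inline double qnan() { return std::numeric_limits<double>::quiet_NaN(); }

// NaN-propagating max/min (assumed model of maximumf/minimumf).
double maximumProp(double x, double y) {
  if (std::isnan(x) || std::isnan(y))
    return qnan();
  return std::fmax(x, y);
}

double minimumProp(double x, double y) {
  if (std::isnan(x) || std::isnan(y))
    return qnan();
  return std::fmin(x, y);
}

// max(min(a, b), a) with "return the non-NaN operand" semantics.
double absorbNum(double a, double b) { return std::fmax(std::fmin(a, b), a); }

// max(min(a, b), a) with NaN-propagating semantics.
double absorbProp(double a, double b) {
  return maximumProp(minimumProp(a, b), a);
}
```

With a = NaN and b = 1.0, `absorbNum` evaluates to 1.0, so folding to `a` would turn a number into NaN. With a = 1.0 and b = NaN, `absorbProp` evaluates to NaN, so folding to `a` would drop a NaN. Either way max(min(a, b), a) -> a changes the result.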