Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,30 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
return getLhs();
}

// max(max(a, b), b) -> max(a, b)
// max(max(a, b), a) -> max(a, b)
if (auto max = getLhs().getDefiningOp<MaxSIOp>())
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
return getLhs();

// max(a, max(a, b)) -> max(a, b)
// max(b, max(a, b)) -> max(a, b)
Comment on lines +1183 to +1190
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These look fine to me

if (auto max = getRhs().getDefiningOp<MaxSIOp>())
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
return getRhs();

// max(min(a, b), a) -> a
// max(min(b, a), a) -> a
if (auto min = getLhs().getDefiningOp<MinSIOp>())
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
return getRhs();

// max(a, min(a, b)) -> a
// max(a, min(b, a)) -> a
Comment on lines +1195 to +1202
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if this is correct when one of the values is poison and gets discarded, but this should be fine since the optimization makes the result more defined: https://alive2.llvm.org/ce/z/cdoy8x

if (auto min = getRhs().getDefiningOp<MinSIOp>())
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
return getLhs();

return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::smax(a, b);
Expand All @@ -1181,6 +1205,30 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
return getLhs();
}

// max(max(a, b), b) -> max(a, b)
// max(max(a, b), a) -> max(a, b)
if (auto max = getLhs().getDefiningOp<MaxUIOp>())
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
return getLhs();

// max(a, max(a, b)) -> max(a, b)
// max(b, max(a, b)) -> max(a, b)
if (auto max = getRhs().getDefiningOp<MaxUIOp>())
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
return getRhs();

// max(min(a, b), a) -> a
// max(min(b, a), a) -> a
if (auto min = getLhs().getDefiningOp<MinUIOp>())
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
return getRhs();

// max(a, min(a, b)) -> a
// max(a, min(b, a)) -> a
if (auto min = getRhs().getDefiningOp<MinUIOp>())
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
return getLhs();

return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::umax(a, b);
Expand Down Expand Up @@ -1242,6 +1290,30 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
return getLhs();
}

// min(min(a, b), b) -> min(a, b)
// min(min(a, b), a) -> min(a, b)
if (auto min = getLhs().getDefiningOp<MinSIOp>())
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
return getLhs();

// min(a, min(a, b)) -> min(a, b)
// min(b, min(a, b)) -> min(a, b)
if (auto min = getRhs().getDefiningOp<MinSIOp>())
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
return getRhs();

// min(max(a, b), a) -> a
// min(max(b, a), a) -> a
if (auto max = getLhs().getDefiningOp<MaxSIOp>())
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
return getRhs();

// min(a, max(a, b)) -> a
// min(a, max(b, a)) -> a
if (auto max = getRhs().getDefiningOp<MaxSIOp>())
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
return getLhs();

return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::smin(a, b);
Expand All @@ -1267,6 +1339,30 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
return getLhs();
}

// min(min(a, b), b) -> min(a, b)
// min(min(a, b), a) -> min(a, b)
if (auto min = getLhs().getDefiningOp<MinUIOp>())
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
return getLhs();

// min(a, min(a, b)) -> min(a, b)
// min(b, min(a, b)) -> min(a, b)
if (auto min = getRhs().getDefiningOp<MinUIOp>())
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
return getRhs();

// min(max(a, b), a) -> a
// min(max(b, a), a) -> a
if (auto max = getLhs().getDefiningOp<MaxUIOp>())
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
return getRhs();

// min(a, max(a, b)) -> a
// min(a, max(b, a)) -> a
if (auto max = getRhs().getDefiningOp<MaxUIOp>())
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
return getLhs();

return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
return llvm::APIntOps::umin(a, b);
Expand Down
137 changes: 136 additions & 1 deletion mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1984,6 +1984,40 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
return %0, %1, %2, %3: i8, i8, i8, i8
}

// CHECK-LABEL: foldMaxsiMaxsi1
// CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32
// CHECK: return %[[MAXSI]] : i32
func.func public @foldMaxsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you drop public from these functions? We don't need it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

%max1 = arith.maxsi %arg1, %arg0 : i32
%max2 = arith.maxsi %max1, %arg1 : i32
func.return %max2 : i32
}

// CHECK-LABEL: foldMaxsiMaxsi2
// CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32
// CHECK: return %[[MAXSI]] : i32
func.func public @foldMaxsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 {
%max1 = arith.maxsi %arg1, %arg0 : i32
%max2 = arith.maxsi %arg1, %max1 : i32
func.return %max2 : i32
}

// CHECK-LABEL: foldMaxsiMinsi1
// CHECK: return %arg0 : i32
func.func public @foldMaxsiMinsi1(%arg0: i32, %arg1: i32) -> i32 {
%min1 = arith.minsi %arg1, %arg0 : i32
%max2 = arith.maxsi %min1, %arg0 : i32
func.return %max2 : i32
}

// CHECK-LABEL: foldMaxsiMinsi2
// CHECK: return %arg0 : i32
func.func public @foldMaxsiMinsi2(%arg0: i32, %arg1: i32) -> i32 {
%min1 = arith.minsi %arg1, %arg0 : i32
%max2 = arith.maxsi %arg0, %min1 : i32
func.return %max2 : i32
}

// -----

// CHECK-LABEL: test_maxui
Expand Down Expand Up @@ -2018,6 +2052,40 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
return %0, %1, %2, %3: i8, i8, i8, i8
}

// CHECK-LABEL: foldMaxuiMaxui1
// CHECK: %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32
// CHECK: return %[[MAXUI]] : i32
func.func public @foldMaxuiMaxui1(%arg0: i32, %arg1: i32) -> i32 {
%max1 = arith.maxui %arg1, %arg0 : i32
%max2 = arith.maxui %max1, %arg1 : i32
func.return %max2 : i32
}

// CHECK-LABEL: foldMaxuiMaxui2
// CHECK: %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32
// CHECK: return %[[MAXUI]] : i32
func.func public @foldMaxuiMaxui2(%arg0: i32, %arg1: i32) -> i32 {
%max1 = arith.maxui %arg1, %arg0 : i32
%max2 = arith.maxui %arg1, %max1 : i32
func.return %max2 : i32
}

// CHECK-LABEL: foldMaxuiMinui1
// CHECK: return %arg0 : i32
func.func public @foldMaxuiMinui1(%arg0: i32, %arg1: i32) -> i32 {
%min1 = arith.minui %arg1, %arg0 : i32
%max2 = arith.maxui %min1, %arg0 : i32
func.return %max2 : i32
}

// CHECK-LABEL: foldMaxuiMinui2
// CHECK: return %arg0 : i32
func.func public @foldMaxuiMinui2(%arg0: i32, %arg1: i32) -> i32 {
%min1 = arith.minui %arg1, %arg0 : i32
%max2 = arith.maxui %arg0, %min1 : i32
func.return %max2 : i32
}

// -----

// CHECK-LABEL: test_minsi
Expand Down Expand Up @@ -2052,6 +2120,40 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
return %0, %1, %2, %3: i8, i8, i8, i8
}

// CHECK-LABEL: foldMinsiMinsi1
// CHECK: %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32
// CHECK: return %[[MINSI]] : i32
func.func public @foldMinsiMinsi1(%arg0: i32, %arg1: i32) -> i32 {
%min1 = arith.minsi %arg1, %arg0 : i32
%min2 = arith.minsi %min1, %arg1 : i32
func.return %min2 : i32
}

// CHECK-LABEL: foldMinsiMinsi2
// CHECK: %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32
// CHECK: return %[[MINSI]] : i32
func.func public @foldMinsiMinsi2(%arg0: i32, %arg1: i32) -> i32 {
%min1 = arith.minsi %arg1, %arg0 : i32
%min2 = arith.minsi %arg1, %min1 : i32
func.return %min2 : i32
}

// CHECK-LABEL: foldMinsiMaxsi1
// CHECK: return %arg0 : i32
func.func public @foldMinsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 {
%min1 = arith.maxsi %arg1, %arg0 : i32
%min2 = arith.minsi %min1, %arg0 : i32
func.return %min2 : i32
}

// CHECK-LABEL: foldMinsiMaxsi2
// CHECK: return %arg0 : i32
func.func public @foldMinsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 {
%min1 = arith.maxsi %arg1, %arg0 : i32
%min2 = arith.minsi %arg0, %min1 : i32
func.return %min2 : i32
}

// -----

// CHECK-LABEL: test_minui
Expand Down Expand Up @@ -2086,6 +2188,40 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {
return %0, %1, %2, %3: i8, i8, i8, i8
}

// CHECK-LABEL: foldMinuiMinui1
// CHECK: %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32
// CHECK: return %[[MINUI]] : i32
func.func public @foldMinuiMinui1(%arg0: i32, %arg1: i32) -> i32 {
%min1 = arith.minui %arg1, %arg0 : i32
%min2 = arith.minui %min1, %arg1 : i32
func.return %min2 : i32
}

// CHECK-LABEL: foldMinuiMinui2
// CHECK: %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32
// CHECK: return %[[MINUI]] : i32
func.func public @foldMinuiMinui2(%arg0: i32, %arg1: i32) -> i32 {
%min1 = arith.minui %arg1, %arg0 : i32
%min2 = arith.minui %arg1, %min1 : i32
func.return %min2 : i32
}

// CHECK-LABEL: foldMinuiMaxui1
// CHECK: return %arg0 : i32
func.func public @foldMinuiMaxui1(%arg0: i32, %arg1: i32) -> i32 {
%max1 = arith.maxui %arg1, %arg0 : i32
%min2 = arith.minui %max1, %arg0 : i32
func.return %min2 : i32
}

// CHECK-LABEL: foldMinuiMaxui2
// CHECK: return %arg0 : i32
func.func public @foldMinuiMaxui2(%arg0: i32, %arg1: i32) -> i32 {
%max1 = arith.maxui %arg1, %arg0 : i32
%min2 = arith.minui %arg0, %max1 : i32
func.return %min2 : i32
}

// -----

// CHECK-LABEL: @test_minimumf(
Expand Down Expand Up @@ -3377,4 +3513,3 @@ func.func @unreachable() {
%add = arith.addi %add, %c1_i64 : i64
cf.br ^unreachable
}