From e5d957e87da0b1db107525da6b6c452968f4e6e9 Mon Sep 17 00:00:00 2001 From: Ziliang Zhang Date: Tue, 23 Sep 2025 11:09:38 +0800 Subject: [PATCH 1/4] [mlir][arith] Fold min/max ops using absorption law and redundant consecutive ops Supported folding for arith.maxsi, arith.maxui, arith.minsi, and arith.minui. 1. Fold redundant consecutive min/max operations: max(max(a, b), b) -> max(a, b) max(max(a, b), a) -> max(a, b) max(a, max(a, b)) -> max(a, b) max(b, max(a, b)) -> max(a, b) (similar cases for min) 2. Fold using the absorption law: max(min(a, b), a) -> a max(min(b, a), a) -> a max(a, min(a, b)) -> a max(a, min(b, a)) -> a (similar cases for min) --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 96 +++++++++++++++ mlir/test/Dialect/Arith/canonicalize.mlir | 137 +++++++++++++++++++++- 2 files changed, 232 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 7cfd6d3a98df8..ea95d15b96f0c 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1156,6 +1156,30 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) { return getLhs(); } + // max(max(a, b), b) -> max(a, b) + // max(max(a, b), a) -> max(a, b) + if (auto max = getLhs().getDefiningOp()) + if (getRhs() == max.getRhs() || getRhs() == max.getLhs()) + return getLhs(); + + // max(a, max(a, b)) -> max(a, b) + // max(b, max(a, b)) -> max(a, b) + if (auto max = getRhs().getDefiningOp()) + if (getLhs() == max.getRhs() || getLhs() == max.getLhs()) + return getRhs(); + + // max(min(a, b), a) -> a + // max(min(b, a), a) -> a + if (auto min = getLhs().getDefiningOp()) + if (getRhs() == min.getRhs() || getRhs() == min.getLhs()) + return getRhs(); + + // max(a, min(a, b)) -> a + // max(a, min(b, a)) -> a + if (auto min = getRhs().getDefiningOp()) + if (getLhs() == min.getRhs() || getLhs() == min.getLhs()) + return getLhs(); + return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { return llvm::APIntOps::smax(a, b); @@ -1181,6 +1205,30 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) { return getLhs(); } + // max(max(a, b), b) -> max(a, b) + // max(max(a, b), a) -> max(a, b) + if (auto max = getLhs().getDefiningOp()) + if (getRhs() == max.getRhs() || getRhs() == max.getLhs()) + return getLhs(); + + // max(a, max(a, b)) -> max(a, b) + // max(b, max(a, b)) -> max(a, b) + if (auto max = getRhs().getDefiningOp()) + if (getLhs() == max.getRhs() || getLhs() == max.getLhs()) + return getRhs(); + + // max(min(a, b), a) -> a + // max(min(b, a), a) -> a + if (auto min = getLhs().getDefiningOp()) + if (getRhs() == min.getRhs() || getRhs() == min.getLhs()) + return getRhs(); + + // max(a, min(a, b)) -> a + // max(a, min(b, a)) -> a + if (auto min = getRhs().getDefiningOp()) + if (getLhs() == min.getRhs() || getLhs() == min.getLhs()) + return getLhs(); + return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { return llvm::APIntOps::umax(a, b); @@ -1242,6 +1290,30 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) { return getLhs(); } + // min(min(a, b), b) -> min(a, b) + // min(min(a, b), a) -> min(a, b) + if (auto min = getLhs().getDefiningOp()) + if (getRhs() == min.getRhs() || getRhs() == min.getLhs()) + return getLhs(); + + // min(a, min(a, b)) -> min(a, b) + // min(b, min(a, b)) -> min(a, b) + if (auto min = getRhs().getDefiningOp()) + if (getLhs() == min.getRhs() || getLhs() == min.getLhs()) + return getRhs(); + + // min(max(a, b), a) -> a + // min(max(b, a), a) -> a + if (auto max = getLhs().getDefiningOp()) + if (getRhs() == max.getRhs() || getRhs() == max.getLhs()) + return getRhs(); + + // min(a, max(a, b)) -> a + // min(a, max(b, a)) -> a + if (auto max = getRhs().getDefiningOp()) + if (getLhs() == max.getRhs() || getLhs() == max.getLhs()) + return getLhs(); + return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { return llvm::APIntOps::smin(a, b); @@ -1267,6 +1339,30 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) { return getLhs(); } + // min(min(a, b), b) -> min(a, b) + // min(min(a, b), a) -> min(a, b) + if (auto min = getLhs().getDefiningOp()) + if (getRhs() == min.getRhs() || getRhs() == min.getLhs()) + return getLhs(); + + // min(a, min(a, b)) -> min(a, b) + // min(b, min(a, b)) -> min(a, b) + if (auto min = getRhs().getDefiningOp()) + if (getLhs() == min.getRhs() || getLhs() == min.getLhs()) + return getRhs(); + + // min(max(a, b), a) -> a + // min(max(b, a), a) -> a + if (auto max = getLhs().getDefiningOp()) + if (getRhs() == max.getRhs() || getRhs() == max.getLhs()) + return getRhs(); + + // min(a, max(a, b)) -> a + // min(a, max(b, a)) -> a + if (auto max = getRhs().getDefiningOp()) + if (getLhs() == max.getRhs() || getLhs() == max.getLhs()) + return getLhs(); + return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { return llvm::APIntOps::umin(a, b); diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index ca3de3a2d7703..afa53a33e79fe 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1984,6 +1984,40 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) { return %0, %1, %2, %3: i8, i8, i8, i8 } +// CHECK-LABEL: foldMaxsiMaxsi1 +// CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32 +// CHECK: return %[[MAXSI]] : i32 +func.func public @foldMaxsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 { + %max1 = arith.maxsi %arg1, %arg0 : i32 + %max2 = arith.maxsi %max1, %arg1 : i32 + func.return %max2 : i32 +} + +// CHECK-LABEL: foldMaxsiMaxsi2 +// CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32 +// CHECK: return %[[MAXSI]] : i32 +func.func public @foldMaxsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 { + %max1 = arith.maxsi %arg1, %arg0 : i32 + %max2 = arith.maxsi %arg1, %max1 : i32 + func.return %max2 : i32 +} + +// CHECK-LABEL: foldMaxsiMinsi1 +// CHECK: return %arg0 : i32 +func.func public @foldMaxsiMinsi1(%arg0: i32, %arg1: i32) -> i32 { + %min1 = arith.minsi %arg1, %arg0 : i32 + %max2 = arith.maxsi %min1, %arg0 : i32 + func.return %max2 : i32 +} + +// CHECK-LABEL: foldMaxsiMinsi2 +// CHECK: return %arg0 : i32 +func.func public @foldMaxsiMinsi2(%arg0: i32, %arg1: i32) -> i32 { + %min1 = arith.minsi %arg1, %arg0 : i32 + %max2 = arith.maxsi %arg0, %min1 : i32 + func.return %max2 : i32 +} + // ----- // CHECK-LABEL: test_maxui @@ -2018,6 +2052,40 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) { return %0, %1, %2, %3: i8, i8, i8, i8 } +// CHECK-LABEL: foldMaxuiMaxui1 +// CHECK: %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32 +// CHECK: return %[[MAXUI]] : i32 +func.func public @foldMaxuiMaxui1(%arg0: i32, %arg1: i32) -> i32 { + %max1 = arith.maxui %arg1, %arg0 : i32 + %max2 = arith.maxui %max1, %arg1 : i32 + func.return %max2 : i32 +} + +// CHECK-LABEL: foldMaxuiMaxui2 +// CHECK: %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32 +// CHECK: return %[[MAXUI]] : i32 +func.func public @foldMaxuiMaxui2(%arg0: i32, %arg1: i32) -> i32 { + %max1 = arith.maxui %arg1, %arg0 : i32 + %max2 = arith.maxui %arg1, %max1 : i32 + func.return %max2 : i32 +} + +// CHECK-LABEL: foldMaxuiMinui1 +// CHECK: return %arg0 : i32 +func.func public @foldMaxuiMinui1(%arg0: i32, %arg1: i32) -> i32 { + %min1 = arith.minui %arg1, %arg0 : i32 + %max2 = arith.maxui %min1, %arg0 : i32 + func.return %max2 : i32 +} + +// CHECK-LABEL: foldMaxuiMinui2 +// CHECK: return %arg0 : i32 +func.func public @foldMaxuiMinui2(%arg0: i32, %arg1: i32) -> i32 { + %min1 = arith.minui %arg1, %arg0 : i32 + %max2 = arith.maxui %arg0, %min1 : i32 + func.return %max2 : i32 +} + // ----- // CHECK-LABEL: test_minsi @@ -2052,6 +2120,40 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) { return %0, %1, %2, %3: i8, i8, i8, i8 } +// CHECK-LABEL: foldMinsiMinsi1 +// CHECK: %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32 +// CHECK: return %[[MINSI]] : i32 +func.func public @foldMinsiMinsi1(%arg0: i32, %arg1: i32) -> i32 { + %min1 = arith.minsi %arg1, %arg0 : i32 + %min2 = arith.minsi %min1, %arg1 : i32 + func.return %min2 : i32 +} + +// CHECK-LABEL: foldMinsiMinsi2 +// CHECK: %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32 +// CHECK: return %[[MINSI]] : i32 +func.func public @foldMinsiMinsi2(%arg0: i32, %arg1: i32) -> i32 { + %min1 = arith.minsi %arg1, %arg0 : i32 + %min2 = arith.minsi %arg1, %min1 : i32 + func.return %min2 : i32 +} + +// CHECK-LABEL: foldMinsiMaxsi1 +// CHECK: return %arg0 : i32 +func.func public @foldMinsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 { + %min1 = arith.maxsi %arg1, %arg0 : i32 + %min2 = arith.minsi %min1, %arg0 : i32 + func.return %min2 : i32 +} + +// CHECK-LABEL: foldMinsiMaxsi2 +// CHECK: return %arg0 : i32 +func.func public @foldMinsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 { + %min1 = arith.maxsi %arg1, %arg0 : i32 + %min2 = arith.minsi %arg0, %min1 : i32 + func.return %min2 : i32 +} + // ----- // CHECK-LABEL: test_minui @@ -2086,6 +2188,40 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) { return %0, %1, %2, %3: i8, i8, i8, i8 } +// CHECK-LABEL: foldMinuiMinui1 +// CHECK: %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32 +// CHECK: return %[[MINUI]] : i32 +func.func public @foldMinuiMinui1(%arg0: i32, %arg1: i32) -> i32 { + %min1 = arith.minui %arg1, %arg0 : i32 + %min2 = arith.minui %min1, %arg1 : i32 + func.return %min2 : i32 +} + +// CHECK-LABEL: foldMinuiMinui2 +// CHECK: %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32 +// CHECK: return %[[MINUI]] : i32 +func.func public @foldMinuiMinui2(%arg0: i32, %arg1: i32) -> i32 { + %min1 = arith.minui %arg1, %arg0 : i32 + %min2 = arith.minui %arg1, %min1 : i32 + func.return %min2 : i32 +} + +// CHECK-LABEL: foldMinuiMaxui1 +// CHECK: return %arg0 : i32 +func.func public @foldMinuiMaxui1(%arg0: i32, %arg1: i32) -> i32 { + %max1 = arith.maxui %arg1, %arg0 : i32 + %min2 = arith.minui %max1, %arg0 : i32 + func.return %min2 : i32 +} + +// CHECK-LABEL: foldMinuiMaxui2 +// CHECK: return %arg0 : i32 +func.func public @foldMinuiMaxui2(%arg0: i32, %arg1: i32) -> i32 { + %max1 = arith.maxui %arg1, %arg0 : i32 + %min2 = arith.minui %arg0, %max1 : i32 + func.return %min2 : i32 +} + // ----- // CHECK-LABEL: @test_minimumf( @@ -3377,4 +3513,3 @@ func.func @unreachable() { %add = arith.addi %add, %c1_i64 : i64 cf.br ^unreachable } - From 69df9053894c47006719ed9ea5f4b64b1ae79ea0 Mon Sep 17 00:00:00 2001 From: Ziliang Zhang Date: Wed, 24 Sep 2025 10:04:07 +0800 Subject: [PATCH 2/4] remove public --- mlir/test/Dialect/Arith/canonicalize.mlir | 32 +++++++++++------------ 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index afa53a33e79fe..3eba98670d325 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1987,7 +1987,7 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) { // CHECK-LABEL: foldMaxsiMaxsi1 // CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32 // CHECK: return %[[MAXSI]] : i32 -func.func public @foldMaxsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMaxsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 { %max1 = arith.maxsi %arg1, %arg0 : i32 %max2 = arith.maxsi %max1, %arg1 : i32 func.return %max2 : i32 @@ -1996,7 +1996,7 @@ func.func public @foldMaxsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMaxsiMaxsi2 // CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32 // CHECK: return %[[MAXSI]] : i32 -func.func public @foldMaxsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMaxsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 { %max1 = arith.maxsi %arg1, %arg0 : i32 %max2 = arith.maxsi %arg1, %max1 : i32 func.return %max2 : i32 @@ -2004,7 +2004,7 @@ func.func public @foldMaxsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMaxsiMinsi1 // CHECK: return %arg0 : i32 -func.func public @foldMaxsiMinsi1(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMaxsiMinsi1(%arg0: i32, %arg1: i32) -> i32 { %min1 = arith.minsi %arg1, %arg0 : i32 %max2 = arith.maxsi %min1, %arg0 : i32 func.return %max2 : i32 @@ -2012,7 +2012,7 @@ func.func public @foldMaxsiMinsi1(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMaxsiMinsi2 // CHECK: return %arg0 : i32 -func.func public @foldMaxsiMinsi2(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMaxsiMinsi2(%arg0: i32, %arg1: i32) -> i32 { %min1 = arith.minsi %arg1, %arg0 : i32 %max2 = arith.maxsi %arg0, %min1 : i32 func.return %max2 : i32 @@ -2055,7 +2055,7 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) { // CHECK-LABEL: foldMaxuiMaxui1 // CHECK: %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32 // CHECK: return %[[MAXUI]] : i32 -func.func public @foldMaxuiMaxui1(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMaxuiMaxui1(%arg0: i32, %arg1: i32) -> i32 { %max1 = arith.maxui %arg1, %arg0 : i32 %max2 = arith.maxui %max1, %arg1 : i32 func.return %max2 : i32 @@ -2064,7 +2064,7 @@ func.func public @foldMaxuiMaxui1(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMaxuiMaxui2 // CHECK: %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32 // CHECK: return %[[MAXUI]] : i32 -func.func public @foldMaxuiMaxui2(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMaxuiMaxui2(%arg0: i32, %arg1: i32) -> i32 { %max1 = arith.maxui %arg1, %arg0 : i32 %max2 = arith.maxui %arg1, %max1 : i32 func.return %max2 : i32 @@ -2072,7 +2072,7 @@ func.func public @foldMaxuiMaxui2(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMaxuiMinui1 // CHECK: return %arg0 : i32 -func.func public @foldMaxuiMinui1(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMaxuiMinui1(%arg0: i32, %arg1: i32) -> i32 { %min1 = arith.minui %arg1, %arg0 : i32 %max2 = arith.maxui %min1, %arg0 : i32 func.return %max2 : i32 @@ -2080,7 +2080,7 @@ func.func public @foldMaxuiMinui1(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMaxuiMinui2 // CHECK: return %arg0 : i32 -func.func public @foldMaxuiMinui2(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMaxuiMinui2(%arg0: i32, %arg1: i32) -> i32 { %min1 = arith.minui %arg1, %arg0 : i32 %max2 = arith.maxui %arg0, %min1 : i32 func.return %max2 : i32 @@ -2123,7 +2123,7 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) { // CHECK-LABEL: foldMinsiMinsi1 // CHECK: %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32 // CHECK: return %[[MINSI]] : i32 -func.func public @foldMinsiMinsi1(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMinsiMinsi1(%arg0: i32, %arg1: i32) -> i32 { %min1 = arith.minsi %arg1, %arg0 : i32 %min2 = arith.minsi %min1, %arg1 : i32 func.return %min2 : i32 @@ -2132,7 +2132,7 @@ func.func public @foldMinsiMinsi1(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMinsiMinsi2 // CHECK: %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32 // CHECK: return %[[MINSI]] : i32 -func.func public @foldMinsiMinsi2(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMinsiMinsi2(%arg0: i32, %arg1: i32) -> i32 { %min1 = arith.minsi %arg1, %arg0 : i32 %min2 = arith.minsi %arg1, %min1 : i32 func.return %min2 : i32 @@ -2140,7 +2140,7 @@ func.func public @foldMinsiMinsi2(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMinsiMaxsi1 // CHECK: return %arg0 : i32 -func.func public @foldMinsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMinsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 { %min1 = arith.maxsi %arg1, %arg0 : i32 %min2 = arith.minsi %min1, %arg0 : i32 func.return %min2 : i32 @@ -2148,7 +2148,7 @@ func.func public @foldMinsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMinsiMaxsi2 // CHECK: return %arg0 : i32 -func.func public @foldMinsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMinsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 { %min1 = arith.maxsi %arg1, %arg0 : i32 %min2 = arith.minsi %arg0, %min1 : i32 func.return %min2 : i32 @@ -2191,7 +2191,7 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) { // CHECK-LABEL: foldMinuiMinui1 // CHECK: %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32 // CHECK: return %[[MINUI]] : i32 -func.func public @foldMinuiMinui1(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMinuiMinui1(%arg0: i32, %arg1: i32) -> i32 { %min1 = arith.minui %arg1, %arg0 : i32 %min2 = arith.minui %min1, %arg1 : i32 func.return %min2 : i32 @@ -2200,7 +2200,7 @@ func.func public @foldMinuiMinui1(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMinuiMinui2 // CHECK: %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32 // CHECK: return %[[MINUI]] : i32 -func.func public @foldMinuiMinui2(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMinuiMinui2(%arg0: i32, %arg1: i32) -> i32 { %min1 = arith.minui %arg1, %arg0 : i32 %min2 = arith.minui %arg1, %min1 : i32 func.return %min2 : i32 @@ -2208,7 +2208,7 @@ func.func public @foldMinuiMinui2(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMinuiMaxui1 // CHECK: return %arg0 : i32 -func.func public @foldMinuiMaxui1(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMinuiMaxui1(%arg0: i32, %arg1: i32) -> i32 { %max1 = arith.maxui %arg1, %arg0 : i32 %min2 = arith.minui %max1, %arg0 : i32 func.return %min2 : i32 @@ -2216,7 +2216,7 @@ func.func public @foldMinuiMaxui1(%arg0: i32, %arg1: i32) -> i32 { // CHECK-LABEL: foldMinuiMaxui2 // CHECK: return %arg0 : i32 -func.func public @foldMinuiMaxui2(%arg0: i32, %arg1: i32) -> i32 { +func.func @foldMinuiMaxui2(%arg0: i32, %arg1: i32) -> i32 { %max1 = arith.maxui %arg1, %arg0 : i32 %min2 = arith.minui %arg0, %max1 : i32 func.return %min2 : i32 From d9cd68f2ac93318302754a29f8395cc0e726938f Mon Sep 17 00:00:00 2001 From: Ziliang Zhang Date: Sun, 28 Sep 2025 15:07:39 +0800 Subject: [PATCH 3/4] Support maximumf/minimumf/maxnumf/minnumf --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 48 +++++++++++++++ mlir/test/Dialect/Arith/canonicalize.mlir | 73 +++++++++++++++++++++++ 2 files changed, 121 insertions(+) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index ea95d15b96f0c..1f461f48ebf7e 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1116,6 +1116,18 @@ OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_NegInfFloat())) return getLhs(); + // max(max(a, b), b) -> max(a, b) + // max(max(a, b), a) -> max(a, b) + if (auto max = getLhs().getDefiningOp()) + if (getRhs() == max.getRhs() || getRhs() == max.getLhs()) + return getLhs(); + + // max(a, max(a, b)) -> max(a, b) + // max(b, max(a, b)) -> max(a, b) + if (auto max = getRhs().getDefiningOp()) + if (getLhs() == max.getRhs() || getLhs() == max.getLhs()) + return getRhs(); + return constFoldBinaryOp( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); @@ -1134,6 +1146,18 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_NaNFloat())) return getLhs(); + // max(max(a, b), b) -> max(a, b) + // max(max(a, b), a) -> max(a, b) + if (auto max = getLhs().getDefiningOp()) + if (getRhs() == max.getRhs() || getRhs() == max.getLhs()) + return getLhs(); + + // max(a, max(a, b)) -> max(a, b) + // max(b, max(a, b)) -> max(a, b) + if (auto max = getRhs().getDefiningOp()) + if (getLhs() == max.getRhs() || getLhs() == max.getLhs()) + return getRhs(); + return constFoldBinaryOp(adaptor.getOperands(), llvm::maxnum); } @@ -1248,6 +1272,18 @@ OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) return getLhs(); + // min(min(a, b), b) -> min(a, b) + // min(min(a, b), a) -> min(a, b) + if (auto min = getLhs().getDefiningOp()) + if (getRhs() == min.getRhs() || getRhs() == min.getLhs()) + return getLhs(); + + // min(a, min(a, b)) -> min(a, b) + // min(b, min(a, b)) -> min(a, b) + if (auto min = getRhs().getDefiningOp()) + if (getLhs() == min.getRhs() || getLhs() == min.getLhs()) + return getRhs(); + return constFoldBinaryOp( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); @@ -1266,6 +1302,18 @@ OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_NaNFloat())) return getLhs(); + // min(min(a, b), b) -> min(a, b) + // min(min(a, b), a) -> min(a, b) + if (auto min = getLhs().getDefiningOp()) + if (getRhs() == min.getRhs() || getRhs() == min.getLhs()) + return getLhs(); + + // min(a, min(a, b)) -> min(a, b) + // min(b, min(a, b)) -> min(a, b) + if (auto min = getRhs().getDefiningOp()) + if (getLhs() == min.getRhs() || getLhs() == min.getLhs()) + return getRhs(); + return constFoldBinaryOp( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); }); diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 3eba98670d325..aa1e30e9c9c39 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2237,6 +2237,24 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) { return %0, %1, %2 : f32, f32, f32 } +// CHECK-LABEL: foldMinimumfMinimumf1 +// CHECK: %[[MINF:.*]] = arith.minimumf %arg1, %arg0 : f32 +// CHECK: return %[[MINF]] : f32 +func.func public @foldMinimumfMinimumf1(%arg0: f32, %arg1: f32) -> f32 { + %min1 = arith.minimumf %arg1, %arg0 : f32 + %min2 = arith.minimumf %min1, %arg1 : f32 + func.return %min2 : f32 +} + +// CHECK-LABEL: foldMinimumfMinimumf2 +// CHECK: %[[MINF:.*]] = arith.minimumf %arg1, %arg0 : f32 +// CHECK: return %[[MINF]] : f32 +func.func public @foldMinimumfMinimumf2(%arg0: f32, %arg1: f32) -> f32 { + %min1 = arith.minimumf %arg1, %arg0 : f32 + %min2 = arith.minimumf %arg1, %min1 : f32 + func.return %min2 : f32 +} + // ----- // CHECK-LABEL: @test_maximumf( @@ -2252,6 +2270,24 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) { return %0, %1, %2 : f32, f32, f32 } +// CHECK-LABEL: foldMaximumfMaximumf1 +// CHECK: %[[MAXF:.*]] = arith.maximumf %arg1, %arg0 : f32 +// CHECK: return %[[MAXF]] : f32 +func.func public @foldMaximumfMaximumf1(%arg0: f32, %arg1: f32) -> f32 { + %max1 = arith.maximumf %arg1, %arg0 : f32 + %max2 = arith.maximumf %max1, %arg1 : f32 + func.return %max2 : f32 +} + +// CHECK-LABEL: foldMaximumfMaximumf2 +// CHECK: %[[MAXF:.*]] = arith.maximumf %arg1, %arg0 : f32 +// CHECK: return %[[MAXF]] : f32 +func.func public @foldMaximumfMaximumf2(%arg0: f32, %arg1: f32) -> f32 { + %max1 = arith.maximumf %arg1, %arg0 : f32 + %max2 = arith.maximumf %arg1, %max1 : f32 + func.return %max2 : f32 +} + // ----- // CHECK-LABEL: @test_minnumf( @@ -2271,6 +2307,25 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32, f32) { return %0, %1, %2, %3 : f32, f32, f32, f32 } +// CHECK-LABEL: foldMinnumfMinnumf1 +// CHECK: %[[MINF:.*]] = arith.minnumf %arg1, %arg0 : f32 +// CHECK: return %[[MINF]] : f32 +func.func public @foldMinnumfMinnumf1(%arg0: f32, %arg1: f32) -> f32 { + %min1 = arith.minnumf %arg1, %arg0 : f32 + %min2 = arith.minnumf %min1, %arg1 : f32 + func.return %min2 : f32 +} + +// CHECK-LABEL: foldMinnumfMinnumf2 +// CHECK: %[[MINF:.*]] = arith.minnumf %arg1, %arg0 : f32 +// CHECK: return %[[MINF]] : f32 +func.func public @foldMinnumfMinnumf2(%arg0: f32, %arg1: f32) -> f32 { + %min1 = arith.minnumf %arg1, %arg0 : f32 + %min2 = arith.minnumf %arg1, %min1 : f32 + func.return %min2 : f32 +} + + // ----- // CHECK-LABEL: @test_maxnumf( @@ -2290,6 +2345,24 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32, f32) { return %0, %1, %2, %3 : f32, f32, f32, f32 } +// CHECK-LABEL: foldMaxnumfMaxnumf1 +// CHECK: %[[MAXF:.*]] = arith.maxnumf %arg1, %arg0 : f32 +// CHECK: return %[[MAXF]] : f32 +func.func public @foldMaxnumfMaxnumf1(%arg0: f32, %arg1: f32) -> f32 { + %max1 = arith.maxnumf %arg1, %arg0 : f32 + %max2 = arith.maxnumf %max1, %arg1 : f32 + func.return %max2 : f32 +} + +// CHECK-LABEL: foldMaxnumfMaxnumf2 +// CHECK: %[[MAXF:.*]] = arith.maxnumf %arg1, %arg0 : f32 +// CHECK: return %[[MAXF]] : f32 +func.func public @foldMaxnumfMaxnumf2(%arg0: f32, %arg1: f32) -> f32 { + %max1 = arith.maxnumf %arg1, %arg0 : f32 + %max2 = arith.maxnumf %arg1, %max1 : f32 + func.return %max2 : f32 +} + // ----- // CHECK-LABEL: @test_addf( From 1b0b94b82347245c2354a2e54586899800d347d2 Mon Sep 17 00:00:00 2001 From: Ziliang Zhang Date: Mon, 29 Sep 2025 11:02:06 +0800 Subject: [PATCH 4/4] Remove public --- mlir/test/Dialect/Arith/canonicalize.mlir | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index aa1e30e9c9c39..4e512dd008c27 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2240,7 +2240,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-LABEL: foldMinimumfMinimumf1 // CHECK: %[[MINF:.*]] = arith.minimumf %arg1, %arg0 : f32 // CHECK: return %[[MINF]] : f32 -func.func public @foldMinimumfMinimumf1(%arg0: f32, %arg1: f32) -> f32 { +func.func @foldMinimumfMinimumf1(%arg0: f32, %arg1: f32) -> f32 { %min1 = arith.minimumf %arg1, %arg0 : f32 %min2 = arith.minimumf %min1, %arg1 : f32 func.return %min2 : f32 @@ -2249,7 +2249,7 @@ func.func public @foldMinimumfMinimumf1(%arg0: f32, %arg1: f32) -> f32 { // CHECK-LABEL: foldMinimumfMinimumf2 // CHECK: %[[MINF:.*]] = arith.minimumf %arg1, %arg0 : f32 // CHECK: return %[[MINF]] : f32 -func.func public @foldMinimumfMinimumf2(%arg0: f32, %arg1: f32) -> f32 { +func.func @foldMinimumfMinimumf2(%arg0: f32, %arg1: f32) -> f32 { %min1 = arith.minimumf %arg1, %arg0 : f32 %min2 = arith.minimumf %arg1, %min1 : f32 func.return %min2 : f32 @@ -2273,7 +2273,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-LABEL: foldMaximumfMaximumf1 // CHECK: %[[MAXF:.*]] = arith.maximumf %arg1, %arg0 : f32 // CHECK: return %[[MAXF]] : f32 -func.func public @foldMaximumfMaximumf1(%arg0: f32, %arg1: f32) -> f32 { +func.func @foldMaximumfMaximumf1(%arg0: f32, %arg1: f32) -> f32 { %max1 = arith.maximumf %arg1, %arg0 : f32 %max2 = arith.maximumf %max1, %arg1 : f32 func.return %max2 : f32 @@ -2282,7 +2282,7 @@ func.func public @foldMaximumfMaximumf1(%arg0: f32, %arg1: f32) -> f32 { // CHECK-LABEL: foldMaximumfMaximumf2 // CHECK: %[[MAXF:.*]] = arith.maximumf %arg1, %arg0 : f32 // CHECK: return %[[MAXF]] : f32 -func.func public @foldMaximumfMaximumf2(%arg0: f32, %arg1: f32) -> f32 { +func.func @foldMaximumfMaximumf2(%arg0: f32, %arg1: f32) -> f32 { %max1 = arith.maximumf %arg1, %arg0 : f32 %max2 = arith.maximumf %arg1, %max1 : f32 func.return %max2 : f32 @@ -2310,7 +2310,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32, f32) { // CHECK-LABEL: foldMinnumfMinnumf1 // CHECK: %[[MINF:.*]] = arith.minnumf %arg1, %arg0 : f32 // CHECK: return %[[MINF]] : f32 -func.func public @foldMinnumfMinnumf1(%arg0: f32, %arg1: f32) -> f32 { +func.func @foldMinnumfMinnumf1(%arg0: f32, %arg1: f32) -> f32 { %min1 = arith.minnumf %arg1, %arg0 : f32 %min2 = arith.minnumf %min1, %arg1 : f32 func.return %min2 : f32 @@ -2319,7 +2319,7 @@ func.func public @foldMinnumfMinnumf1(%arg0: f32, %arg1: f32) -> f32 { // CHECK-LABEL: foldMinnumfMinnumf2 // CHECK: %[[MINF:.*]] = arith.minnumf %arg1, %arg0 : f32 // CHECK: return %[[MINF]] : f32 -func.func public @foldMinnumfMinnumf2(%arg0: f32, %arg1: f32) -> f32 { +func.func @foldMinnumfMinnumf2(%arg0: f32, %arg1: f32) -> f32 { %min1 = arith.minnumf %arg1, %arg0 : f32 %min2 = arith.minnumf %arg1, %min1 : f32 func.return %min2 : f32 @@ -2348,7 +2348,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32, f32) { // CHECK-LABEL: foldMaxnumfMaxnumf1 // CHECK: %[[MAXF:.*]] = arith.maxnumf %arg1, %arg0 : f32 // CHECK: return %[[MAXF]] : f32 -func.func public @foldMaxnumfMaxnumf1(%arg0: f32, %arg1: f32) -> f32 { +func.func @foldMaxnumfMaxnumf1(%arg0: f32, %arg1: f32) -> f32 { %max1 = arith.maxnumf %arg1, %arg0 : f32 %max2 = arith.maxnumf %max1, %arg1 : f32 func.return %max2 : f32 @@ -2357,7 +2357,7 @@ func.func public @foldMaxnumfMaxnumf1(%arg0: f32, %arg1: f32) -> f32 { // CHECK-LABEL: foldMaxnumfMaxnumf2 // CHECK: %[[MAXF:.*]] = arith.maxnumf %arg1, %arg0 : f32 // CHECK: return %[[MAXF]] : f32 -func.func public @foldMaxnumfMaxnumf2(%arg0: f32, %arg1: f32) -> f32 { +func.func @foldMaxnumfMaxnumf2(%arg0: f32, %arg1: f32) -> f32 { %max1 = arith.maxnumf %arg1, %arg0 : f32 %max2 = arith.maxnumf %arg1, %max1 : f32 func.return %max2 : f32