Skip to content

Commit e5d957e

Browse files
committed
[mlir][arith] Fold min/max ops using absorption law and redundant consecutive ops
Supported folding for arith.maxsi, arith.maxui, arith.minsi, and arith.minui. 1. Fold redundant consecutive min/max operations: max(max(a, b), b) -> max(a, b) max(max(a, b), a) -> max(a, b) max(a, max(a, b)) -> max(a, b) max(b, max(a, b)) -> max(a, b) (similar cases for min) 2. Fold using the absorption law: max(min(a, b), a) -> a max(min(b, a), a) -> a max(a, min(a, b)) -> a max(a, min(b, a)) -> a (similar cases for min)
1 parent 322b990 commit e5d957e

File tree

2 files changed

+232
-1
lines changed

2 files changed

+232
-1
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,30 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
11561156
return getLhs();
11571157
}
11581158

1159+
// max(max(a, b), b) -> max(a, b)
1160+
// max(max(a, b), a) -> max(a, b)
1161+
if (auto max = getLhs().getDefiningOp<MaxSIOp>())
1162+
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
1163+
return getLhs();
1164+
1165+
// max(a, max(a, b)) -> max(a, b)
1166+
// max(b, max(a, b)) -> max(a, b)
1167+
if (auto max = getRhs().getDefiningOp<MaxSIOp>())
1168+
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
1169+
return getRhs();
1170+
1171+
// max(min(a, b), a) -> a
1172+
// max(min(b, a), a) -> a
1173+
if (auto min = getLhs().getDefiningOp<MinSIOp>())
1174+
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
1175+
return getRhs();
1176+
1177+
// max(a, min(a, b)) -> a
1178+
// max(a, min(b, a)) -> a
1179+
if (auto min = getRhs().getDefiningOp<MinSIOp>())
1180+
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
1181+
return getLhs();
1182+
11591183
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
11601184
[](const APInt &a, const APInt &b) {
11611185
return llvm::APIntOps::smax(a, b);
@@ -1181,6 +1205,30 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
11811205
return getLhs();
11821206
}
11831207

1208+
// max(max(a, b), b) -> max(a, b)
1209+
// max(max(a, b), a) -> max(a, b)
1210+
if (auto max = getLhs().getDefiningOp<MaxUIOp>())
1211+
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
1212+
return getLhs();
1213+
1214+
// max(a, max(a, b)) -> max(a, b)
1215+
// max(b, max(a, b)) -> max(a, b)
1216+
if (auto max = getRhs().getDefiningOp<MaxUIOp>())
1217+
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
1218+
return getRhs();
1219+
1220+
// max(min(a, b), a) -> a
1221+
// max(min(b, a), a) -> a
1222+
if (auto min = getLhs().getDefiningOp<MinUIOp>())
1223+
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
1224+
return getRhs();
1225+
1226+
// max(a, min(a, b)) -> a
1227+
// max(a, min(b, a)) -> a
1228+
if (auto min = getRhs().getDefiningOp<MinUIOp>())
1229+
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
1230+
return getLhs();
1231+
11841232
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
11851233
[](const APInt &a, const APInt &b) {
11861234
return llvm::APIntOps::umax(a, b);
@@ -1242,6 +1290,30 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
12421290
return getLhs();
12431291
}
12441292

1293+
// min(min(a, b), b) -> min(a, b)
1294+
// min(min(a, b), a) -> min(a, b)
1295+
if (auto min = getLhs().getDefiningOp<MinSIOp>())
1296+
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
1297+
return getLhs();
1298+
1299+
// min(a, min(a, b)) -> min(a, b)
1300+
// min(b, min(a, b)) -> min(a, b)
1301+
if (auto min = getRhs().getDefiningOp<MinSIOp>())
1302+
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
1303+
return getRhs();
1304+
1305+
// min(max(a, b), a) -> a
1306+
// min(max(b, a), a) -> a
1307+
if (auto max = getLhs().getDefiningOp<MaxSIOp>())
1308+
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
1309+
return getRhs();
1310+
1311+
// min(a, max(a, b)) -> a
1312+
// min(a, max(b, a)) -> a
1313+
if (auto max = getRhs().getDefiningOp<MaxSIOp>())
1314+
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
1315+
return getLhs();
1316+
12451317
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
12461318
[](const APInt &a, const APInt &b) {
12471319
return llvm::APIntOps::smin(a, b);
@@ -1267,6 +1339,30 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
12671339
return getLhs();
12681340
}
12691341

1342+
// min(min(a, b), b) -> min(a, b)
1343+
// min(min(a, b), a) -> min(a, b)
1344+
if (auto min = getLhs().getDefiningOp<MinUIOp>())
1345+
if (getRhs() == min.getRhs() || getRhs() == min.getLhs())
1346+
return getLhs();
1347+
1348+
// min(a, min(a, b)) -> min(a, b)
1349+
// min(b, min(a, b)) -> min(a, b)
1350+
if (auto min = getRhs().getDefiningOp<MinUIOp>())
1351+
if (getLhs() == min.getRhs() || getLhs() == min.getLhs())
1352+
return getRhs();
1353+
1354+
// min(max(a, b), a) -> a
1355+
// min(max(b, a), a) -> a
1356+
if (auto max = getLhs().getDefiningOp<MaxUIOp>())
1357+
if (getRhs() == max.getRhs() || getRhs() == max.getLhs())
1358+
return getRhs();
1359+
1360+
// min(a, max(a, b)) -> a
1361+
// min(a, max(b, a)) -> a
1362+
if (auto max = getRhs().getDefiningOp<MaxUIOp>())
1363+
if (getLhs() == max.getRhs() || getLhs() == max.getLhs())
1364+
return getLhs();
1365+
12701366
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
12711367
[](const APInt &a, const APInt &b) {
12721368
return llvm::APIntOps::umin(a, b);

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1984,6 +1984,40 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
19841984
return %0, %1, %2, %3: i8, i8, i8, i8
19851985
}
19861986

1987+
// CHECK-LABEL: foldMaxsiMaxsi1
1988+
// CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32
1989+
// CHECK: return %[[MAXSI]] : i32
1990+
func.func public @foldMaxsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 {
1991+
%max1 = arith.maxsi %arg1, %arg0 : i32
1992+
%max2 = arith.maxsi %max1, %arg1 : i32
1993+
func.return %max2 : i32
1994+
}
1995+
1996+
// CHECK-LABEL: foldMaxsiMaxsi2
1997+
// CHECK: %[[MAXSI:.*]] = arith.maxsi %arg1, %arg0 : i32
1998+
// CHECK: return %[[MAXSI]] : i32
1999+
func.func public @foldMaxsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 {
2000+
%max1 = arith.maxsi %arg1, %arg0 : i32
2001+
%max2 = arith.maxsi %arg1, %max1 : i32
2002+
func.return %max2 : i32
2003+
}
2004+
2005+
// CHECK-LABEL: foldMaxsiMinsi1
2006+
// CHECK: return %arg0 : i32
2007+
func.func public @foldMaxsiMinsi1(%arg0: i32, %arg1: i32) -> i32 {
2008+
%min1 = arith.minsi %arg1, %arg0 : i32
2009+
%max2 = arith.maxsi %min1, %arg0 : i32
2010+
func.return %max2 : i32
2011+
}
2012+
2013+
// CHECK-LABEL: foldMaxsiMinsi2
2014+
// CHECK: return %arg0 : i32
2015+
func.func public @foldMaxsiMinsi2(%arg0: i32, %arg1: i32) -> i32 {
2016+
%min1 = arith.minsi %arg1, %arg0 : i32
2017+
%max2 = arith.maxsi %arg0, %min1 : i32
2018+
func.return %max2 : i32
2019+
}
2020+
19872021
// -----
19882022

19892023
// CHECK-LABEL: test_maxui
@@ -2018,6 +2052,40 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
20182052
return %0, %1, %2, %3: i8, i8, i8, i8
20192053
}
20202054

2055+
// CHECK-LABEL: foldMaxuiMaxui1
2056+
// CHECK: %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32
2057+
// CHECK: return %[[MAXUI]] : i32
2058+
func.func public @foldMaxuiMaxui1(%arg0: i32, %arg1: i32) -> i32 {
2059+
%max1 = arith.maxui %arg1, %arg0 : i32
2060+
%max2 = arith.maxui %max1, %arg1 : i32
2061+
func.return %max2 : i32
2062+
}
2063+
2064+
// CHECK-LABEL: foldMaxuiMaxui2
2065+
// CHECK: %[[MAXUI:.*]] = arith.maxui %arg1, %arg0 : i32
2066+
// CHECK: return %[[MAXUI]] : i32
2067+
func.func public @foldMaxuiMaxui2(%arg0: i32, %arg1: i32) -> i32 {
2068+
%max1 = arith.maxui %arg1, %arg0 : i32
2069+
%max2 = arith.maxui %arg1, %max1 : i32
2070+
func.return %max2 : i32
2071+
}
2072+
2073+
// CHECK-LABEL: foldMaxuiMinui1
2074+
// CHECK: return %arg0 : i32
2075+
func.func public @foldMaxuiMinui1(%arg0: i32, %arg1: i32) -> i32 {
2076+
%min1 = arith.minui %arg1, %arg0 : i32
2077+
%max2 = arith.maxui %min1, %arg0 : i32
2078+
func.return %max2 : i32
2079+
}
2080+
2081+
// CHECK-LABEL: foldMaxuiMinui2
2082+
// CHECK: return %arg0 : i32
2083+
func.func public @foldMaxuiMinui2(%arg0: i32, %arg1: i32) -> i32 {
2084+
%min1 = arith.minui %arg1, %arg0 : i32
2085+
%max2 = arith.maxui %arg0, %min1 : i32
2086+
func.return %max2 : i32
2087+
}
2088+
20212089
// -----
20222090

20232091
// CHECK-LABEL: test_minsi
@@ -2052,6 +2120,40 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
20522120
return %0, %1, %2, %3: i8, i8, i8, i8
20532121
}
20542122

2123+
// CHECK-LABEL: foldMinsiMinsi1
2124+
// CHECK: %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32
2125+
// CHECK: return %[[MINSI]] : i32
2126+
func.func public @foldMinsiMinsi1(%arg0: i32, %arg1: i32) -> i32 {
2127+
%min1 = arith.minsi %arg1, %arg0 : i32
2128+
%min2 = arith.minsi %min1, %arg1 : i32
2129+
func.return %min2 : i32
2130+
}
2131+
2132+
// CHECK-LABEL: foldMinsiMinsi2
2133+
// CHECK: %[[MINSI:.*]] = arith.minsi %arg1, %arg0 : i32
2134+
// CHECK: return %[[MINSI]] : i32
2135+
func.func public @foldMinsiMinsi2(%arg0: i32, %arg1: i32) -> i32 {
2136+
%min1 = arith.minsi %arg1, %arg0 : i32
2137+
%min2 = arith.minsi %arg1, %min1 : i32
2138+
func.return %min2 : i32
2139+
}
2140+
2141+
// CHECK-LABEL: foldMinsiMaxsi1
2142+
// CHECK: return %arg0 : i32
2143+
func.func public @foldMinsiMaxsi1(%arg0: i32, %arg1: i32) -> i32 {
2144+
%min1 = arith.maxsi %arg1, %arg0 : i32
2145+
%min2 = arith.minsi %min1, %arg0 : i32
2146+
func.return %min2 : i32
2147+
}
2148+
2149+
// CHECK-LABEL: foldMinsiMaxsi2
2150+
// CHECK: return %arg0 : i32
2151+
func.func public @foldMinsiMaxsi2(%arg0: i32, %arg1: i32) -> i32 {
2152+
%min1 = arith.maxsi %arg1, %arg0 : i32
2153+
%min2 = arith.minsi %arg0, %min1 : i32
2154+
func.return %min2 : i32
2155+
}
2156+
20552157
// -----
20562158

20572159
// CHECK-LABEL: test_minui
@@ -2086,6 +2188,40 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {
20862188
return %0, %1, %2, %3: i8, i8, i8, i8
20872189
}
20882190

2191+
// CHECK-LABEL: foldMinuiMinui1
2192+
// CHECK: %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32
2193+
// CHECK: return %[[MINUI]] : i32
2194+
func.func public @foldMinuiMinui1(%arg0: i32, %arg1: i32) -> i32 {
2195+
%min1 = arith.minui %arg1, %arg0 : i32
2196+
%min2 = arith.minui %min1, %arg1 : i32
2197+
func.return %min2 : i32
2198+
}
2199+
2200+
// CHECK-LABEL: foldMinuiMinui2
2201+
// CHECK: %[[MINUI:.*]] = arith.minui %arg1, %arg0 : i32
2202+
// CHECK: return %[[MINUI]] : i32
2203+
func.func public @foldMinuiMinui2(%arg0: i32, %arg1: i32) -> i32 {
2204+
%min1 = arith.minui %arg1, %arg0 : i32
2205+
%min2 = arith.minui %arg1, %min1 : i32
2206+
func.return %min2 : i32
2207+
}
2208+
2209+
// CHECK-LABEL: foldMinuiMaxui1
2210+
// CHECK: return %arg0 : i32
2211+
func.func public @foldMinuiMaxui1(%arg0: i32, %arg1: i32) -> i32 {
2212+
%max1 = arith.maxui %arg1, %arg0 : i32
2213+
%min2 = arith.minui %max1, %arg0 : i32
2214+
func.return %min2 : i32
2215+
}
2216+
2217+
// CHECK-LABEL: foldMinuiMaxui2
2218+
// CHECK: return %arg0 : i32
2219+
func.func public @foldMinuiMaxui2(%arg0: i32, %arg1: i32) -> i32 {
2220+
%max1 = arith.maxui %arg1, %arg0 : i32
2221+
%min2 = arith.minui %arg0, %max1 : i32
2222+
func.return %min2 : i32
2223+
}
2224+
20892225
// -----
20902226

20912227
// CHECK-LABEL: @test_minimumf(
@@ -3377,4 +3513,3 @@ func.func @unreachable() {
33773513
%add = arith.addi %add, %c1_i64 : i64
33783514
cf.br ^unreachable
33793515
}
3380-

0 commit comments

Comments
 (0)