Skip to content

Commit 4d4a8b1

Browse files
committed
Fix arith maxnumf/minnumf folder
1 parent 4a505e1 commit 4d4a8b1

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

mlir/include/mlir/IR/Matchers.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,11 @@ inline detail::constant_float_predicate_matcher m_OneFloat() {
417417
}};
418418
}
419419

420+
/// Matches a constant scalar / vector splat / tensor splat float ones.
421+
inline detail::constant_float_predicate_matcher m_NaNFloat() {
422+
return {[](const APFloat &value) { return value.isNaN(); }};
423+
}
424+
420425
/// Matches a constant scalar / vector splat / tensor splat float positive
421426
/// infinity.
422427
inline detail::constant_float_predicate_matcher m_PosInfFloat() {

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,13 +1014,11 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
10141014
if (getLhs() == getRhs())
10151015
return getRhs();
10161016

1017-
// maxnumf(x, -inf) -> x
1018-
if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1017+
// maxnumf(x, NaN) -> x
1018+
if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
10191019
return getLhs();
10201020

1021-
return constFoldBinaryOp<FloatAttr>(
1022-
adaptor.getOperands(),
1023-
[](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1021+
return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
10241022
}
10251023

10261024
//===----------------------------------------------------------------------===//
@@ -1100,8 +1098,8 @@ OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
11001098
if (getLhs() == getRhs())
11011099
return getRhs();
11021100

1103-
// minnumf(x, +inf) -> x
1104-
if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1101+
// minnumf(x, NaN) -> x
1102+
if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
11051103
return getLhs();
11061104

11071105
return constFoldBinaryOp<FloatAttr>(

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,31 +1905,39 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
19051905
// -----
19061906

19071907
// CHECK-LABEL: @test_minnumf(
1908-
func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
1908+
func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
19091909
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
1910+
// CHECK-DAG: %[[INF:.+]] = arith.constant
19101911
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
1911-
// CHECK-NEXT: return %[[X]], %arg0, %arg0
1912+
// CHECK-NEXT: %[[Y:.+]] = arith.minnumf %arg0, %[[INF]]
1913+
// CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0
19121914
%c0 = arith.constant 0.0 : f32
19131915
%inf = arith.constant 0x7F800000 : f32
1916+
%nan = arith.constant 0x7FC00000 : f32
19141917
%0 = arith.minnumf %c0, %arg0 : f32
19151918
%1 = arith.minnumf %arg0, %arg0 : f32
19161919
%2 = arith.minnumf %inf, %arg0 : f32
1917-
return %0, %1, %2 : f32, f32, f32
1920+
%3 = arith.minnumf %nan, %arg0 : f32
1921+
return %0, %1, %2, %3 : f32, f32, f32, f32
19181922
}
19191923

19201924
// -----
19211925

19221926
// CHECK-LABEL: @test_maxnumf(
1923-
func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
1924-
// CHECK-DAG: %[[C0:.+]] = arith.constant
1927+
func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32, f32) {
1928+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
1929+
// CHECK-DAG: %[[NINF:.+]] = arith.constant
19251930
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
1926-
// CHECK-NEXT: return %[[X]], %arg0, %arg0
1931+
// CHECK-NEXT: %[[Y:.+]] = arith.maxnumf %arg0, %[[NINF]]
1932+
// CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0
19271933
%c0 = arith.constant 0.0 : f32
19281934
%-inf = arith.constant 0xFF800000 : f32
1935+
%nan = arith.constant 0x7FC00000 : f32
19291936
%0 = arith.maxnumf %c0, %arg0 : f32
19301937
%1 = arith.maxnumf %arg0, %arg0 : f32
19311938
%2 = arith.maxnumf %-inf, %arg0 : f32
1932-
return %0, %1, %2 : f32, f32, f32
1939+
%3 = arith.maxnumf %nan, %arg0 : f32
1940+
return %0, %1, %2, %3 : f32, f32, f32, f32
19331941
}
19341942

19351943
// -----

0 commit comments

Comments
 (0)