diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 6fa5a47109d20..d218206e50f8f 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -417,6 +417,11 @@ inline detail::constant_float_predicate_matcher m_OneFloat() { }}; } +/// Matches a constant scalar / vector splat / tensor splat float ones. +inline detail::constant_float_predicate_matcher m_NaNFloat() { + return {[](const APFloat &value) { return value.isNaN(); }}; +} + /// Matches a constant scalar / vector splat / tensor splat float positive /// infinity. inline detail::constant_float_predicate_matcher m_PosInfFloat() { diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 254f54d9e459e..ea74121261cc4 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1014,13 +1014,11 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) { if (getLhs() == getRhs()) return getRhs(); - // maxnumf(x, -inf) -> x - if (matchPattern(adaptor.getRhs(), m_NegInfFloat())) + // maxnumf(x, NaN) -> x + if (matchPattern(adaptor.getRhs(), m_NaNFloat())) return getLhs(); - return constFoldBinaryOp( - adaptor.getOperands(), - [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); + return constFoldBinaryOp(adaptor.getOperands(), llvm::maxnum); } //===----------------------------------------------------------------------===// @@ -1100,8 +1098,8 @@ OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) { if (getLhs() == getRhs()) return getRhs(); - // minnumf(x, +inf) -> x - if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) + // minnumf(x, NaN) -> x + if (matchPattern(adaptor.getRhs(), m_NaNFloat())) return getLhs(); return constFoldBinaryOp( diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index a386a178b7899..84f2b0f113a0c 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1905,31 +1905,39 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) { // ----- // CHECK-LABEL: @test_minnumf( -func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) { +func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32, f32) { // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0 + // CHECK-DAG: %[[INF:.+]] = arith.constant // CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]] - // CHECK-NEXT: return %[[X]], %arg0, %arg0 + // CHECK-NEXT: %[[Y:.+]] = arith.minnumf %arg0, %[[INF]] + // CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0 %c0 = arith.constant 0.0 : f32 %inf = arith.constant 0x7F800000 : f32 + %nan = arith.constant 0x7FC00000 : f32 %0 = arith.minnumf %c0, %arg0 : f32 %1 = arith.minnumf %arg0, %arg0 : f32 %2 = arith.minnumf %inf, %arg0 : f32 - return %0, %1, %2 : f32, f32, f32 + %3 = arith.minnumf %nan, %arg0 : f32 + return %0, %1, %2, %3 : f32, f32, f32, f32 } // ----- // CHECK-LABEL: @test_maxnumf( -func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) { - // CHECK-DAG: %[[C0:.+]] = arith.constant +func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32, f32) { + // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0 + // CHECK-DAG: %[[NINF:.+]] = arith.constant // CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]] - // CHECK-NEXT: return %[[X]], %arg0, %arg0 + // CHECK-NEXT: %[[Y:.+]] = arith.maxnumf %arg0, %[[NINF]] + // CHECK-NEXT: return %[[X]], %arg0, %[[Y]], %arg0 %c0 = arith.constant 0.0 : f32 %-inf = arith.constant 0xFF800000 : f32 + %nan = arith.constant 0x7FC00000 : f32 %0 = arith.maxnumf %c0, %arg0 : f32 %1 = arith.maxnumf %arg0, %arg0 : f32 %2 = arith.maxnumf %-inf, %arg0 : f32 - return %0, %1, %2 : f32, f32, f32 + %3 = arith.maxnumf %nan, %arg0 : f32 + return %0, %1, %2, %3 : f32, f32, f32, f32 } // -----