diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d9840e3923c4f..133855cc38933 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2892,7 +2892,7 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   dims.erase(dims.begin() + reductionDim);
   // Step 1: Compute max along dim.
   Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
-  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
+  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
                                                  elementType, b, loc,
                                                  /*useOnlyFiniteValue=*/true);
   Value neutralForMaxFInit =
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 2e211d2fa7dbe..72acf43361f50 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -210,7 +210,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // CHECK-LABEL: func.func @softmax(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
-// CHECK-DAG: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFFC00000 : f32
 // CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
 // CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
 // CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
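
For context (not part of the patch): the new fill constant `0xFFC00000` is a negative quiet NaN, which is the identity for `arith.maxnumf` because maxnum-style semantics return the non-NaN operand. Below is a minimal standalone C++ sketch of that property, using `std::fmax` as a stand-in for `arith.maxnumf` since both follow IEEE-754 maxNum NaN handling.

```cpp
// Sketch only: std::fmax, like arith.maxnumf, returns the non-NaN operand
// when exactly one operand is NaN, so NaN is a true identity element.
#include <cassert>
#include <cmath>
#include <limits>

int main() {
  // Negative quiet NaN; bit pattern 0xFFC00000 on typical IEEE-754 platforms.
  float neutral = -std::numeric_limits<float>::quiet_NaN();
  // fmax(NaN, x) == x for every x, so seeding a max-reduction accumulator
  // with NaN never overrides a real input value.
  assert(std::fmax(neutral, -1.0f) == -1.0f);
  assert(std::fmax(neutral, -std::numeric_limits<float>::max()) ==
         -std::numeric_limits<float>::max());
  return 0;
}
```

A NaN-propagating maximum (`arith.maximumf`) would instead return NaN from such a seed, which is why `maximumf` takes a finite neutral element (the old `-3.40282347E+38`, i.e. the smallest finite f32, under `useOnlyFiniteValue=true`).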