Skip to content

Commit 98506b2

Browse files
committed
Fix neutral elt for softmax
1 parent 4ad4d34 commit 98506b2

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2892,7 +2892,7 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
28922892
dims.erase(dims.begin() + reductionDim);
28932893
// Step 1: Compute max along dim.
28942894
Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2895-
Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
2895+
Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
28962896
elementType, b, loc,
28972897
/*useOnlyFiniteValue=*/true);
28982898
Value neutralForMaxFInit =

mlir/test/Dialect/Linalg/transform-op-decompose.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
210210
// CHECK-LABEL: func.func @softmax(
211211
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
212212
// CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
213-
// CHECK-DAG: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
213+
// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFFC00000 : f32
214214
// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
215215
// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
216216
// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {

0 commit comments

Comments
 (0)