
Commit d6eccaa

Fix attention mask to use float_lowest instead of -inf and add unit test for softmax NaN case
1 parent 8a94ad6 commit d6eccaa
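
Why the mask fill value matters: when every position in a row is masked with -inf, the softmax numerator and denominator are both zero and the whole row becomes NaN. A minimal numpy sketch of that failure mode (reference_softmax below is an illustrative helper, not code from this repository):

import numpy as np

def reference_softmax(x, axis=-1):
    # Max-subtracted softmax; for an all -inf row, -inf - (-inf) is already NaN.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

fully_masked = np.array([[-np.inf, -np.inf]], dtype=np.float32)
print(reference_softmax(fully_masked))  # [[nan nan]]

Filling the mask with the dtype's lowest finite value instead keeps every row finite, which is what the change below does.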

File tree

2 files changed: +12 -1 lines changed


onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 5 additions & 1 deletion
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+import numpy as np
 import math
 from typing import Optional, Sequence, Tuple, TypeVar, Union
 

@@ -2048,6 +2049,9 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
     attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
 
+def float_lowest(dtype):
+    """Returns the lowest representable value for the given numpy dtype."""
+    return np.finfo(np.dtype(dtype)).min
 
 
 def _aten_scaled_dot_product_attention_bool_mask_onnx(
     query: TFloat,
@@ -2078,7 +2082,7 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
     key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
     # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
     zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype))
-    neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype))
+    neg_inf = op.Constant(value=ir.tensor(float_lowest(query.dtype), dtype=query.dtype))
     attn_mask = op.Where(attn_mask, zero, neg_inf)
     attn_weight = op.Softmax(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
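
The new float_lowest helper simply looks up numpy's finfo minimum for the requested dtype. A quick standalone check of the values it returns (the helper is repeated here so the snippet runs on its own, assuming plain numpy dtypes as input):

import numpy as np

def float_lowest(dtype):
    """Returns the lowest representable value for the given numpy dtype."""
    return np.finfo(np.dtype(dtype)).min

print(float_lowest(np.float32))   # about -3.4028235e+38
print(float_lowest(np.float16))   # -65504.0
print(float_lowest("float64"))    # about -1.7976931348623157e+308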

tests/common/testutils.py

Lines changed: 7 additions & 0 deletions
@@ -14,6 +14,7 @@
 import torch
 
 from onnxscript import optimizer
+from onnxscript.onnx_opset import opset18 as op
 from onnxscript.rewriter import onnxruntime as ort_rewriter
 from onnxscript.utils import evaluation_utils
 
@@ -101,3 +102,9 @@ def test_onnxruntime_rewrite(
                     f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
                 )
                 raise
+
+def test_softmax_with_all_inf_mask():
+    # GH #2561
+    input = np.array([[-float("inf"), -float("inf")]], dtype=np.float32)
+    output = op.Softmax(input, axis=-1)
+    assert np.isnan(output).all(), "Softmax should return NaN when all inputs are -inf"
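
As a companion to the new test (not part of the commit), the same illustrative reference softmax shows that a row filled with float32's lowest finite value stays well defined, degrading to uniform attention weights instead of NaN:

import numpy as np

def reference_softmax(x, axis=-1):
    # Illustrative max-subtracted softmax, as in the sketch under the commit message.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

fill = np.finfo(np.float32).min
weights = reference_softmax(np.array([[fill, fill]], dtype=np.float32))
assert np.isfinite(weights).all()
assert np.allclose(weights, 0.5)   # a fully masked row collapses to uniform weights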
