From 0068e40ef043d0079e7b8952c74b8925daad8330 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Wed, 6 Aug 2025 22:03:44 +0200
Subject: [PATCH] Match PyTorch behavior for boolean masks in scaled dot
 product attention

---
 onnxscript/function_libs/torch_lib/ops/nn.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 8184fd5eba..1b2ec440bd 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -2076,6 +2076,11 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
     attn_weight = op.Softmax(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
         axis=-1,
     )
+    # When scaled dot product attention is used with a boolean mask, softmax can return NaN:
+    # a fully masked row (all padding tokens) is filled with -inf, so the softmax reduces to
+    # 0/0 (NaN). ONNX has no safe/masked softmax implementation, so the NaN values must be
+    # handled explicitly to match PyTorch's behavior with boolean masks.
+    attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight)
     attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
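
For context, a minimal sketch of the failure mode the op.Where/op.IsNaN guard
addresses, using NumPy to stand in for the ONNX ops (not part of the patch
itself; the array values are illustrative only):

import numpy as np

# A fully masked row: every attention logit is -inf (all key positions are padding).
logits = np.array([-np.inf, -np.inf, -np.inf])

# Naive softmax: exp(-inf) == 0 for every entry, so the row sums to 0 and
# the normalizing division is 0/0 == NaN.
with np.errstate(invalid="ignore"):  # silence the expected 0/0 warning
    weights = np.exp(logits) / np.exp(logits).sum()
print(weights)  # [nan nan nan]

# The patch's guard, mirroring op.Where(op.IsNaN(attn_weight), zero, attn_weight):
# fully masked rows then contribute zeros instead of propagating NaN into the
# final MatMul with the value tensor.
safe_weights = np.where(np.isnan(weights), 0.0, weights)
print(safe_weights)  # [0. 0. 0.]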