-
Notifications
You must be signed in to change notification settings - Fork 103
Make onnx export SDPA match aten behavior #2479
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2076,6 +2076,11 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx( | |
| op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask), | ||
| axis=-1, | ||
| ) | ||
| # When using scaled dot product attention with a boolean mask, the softmax operation might return NaN values | ||
| # due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output. | ||
| # This is because there's no safe/masked softmax imp in ONNX, so we need to handle NaN values explicitly to match | ||
| # the behavior of PyTorch with boolean masks. | ||
| attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight) | ||
| attn_weight, _ = op.Dropout(attn_weight, dropout_p) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @titaiwangms we should probably conditionally skip this line (even though there is a rewrite rule already)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you fix this, can you also please add a reference to pytorch/pytorch#103749 in the comments for the previous line fixing NaN?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We skip when
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| return op.MatMul(attn_weight, value) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.