Commit ecb7677
authored
Make onnx export SDPA match aten behavior (#2479)
This PR makes the ONNX SDPA export match the behavior of aten SDPA when a
boolean mask is used.
```python
import onnxruntime as ort
import torch


class ScaledDotProductAttention(torch.nn.Module):
    """Minimal module that forwards to ``scaled_dot_product_attention``."""

    def forward(self, query, key, value, attn_mask):
        """Return the attention output for the given inputs and mask."""
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask
        )


model = ScaledDotProductAttention()
attn_mask = torch.ones(2, 4, 8, 8).bool()  # boolean mask for attention
attn_mask[0, 0, 0, :] = False  # masking an entire row (padding token)
query = key = value = torch.randn(2, 4, 8, 16)

# Eager-mode result, used as the reference for the exported model.
output = model(query, key, value, attn_mask)

torch.onnx.export(
    model,
    (query, key, value, attn_mask),
    "scaled_dot_product_attention.onnx",
    input_names=["query", "key", "value", "attn_mask"],
    output_names=["output"],
    opset_version=18,
    dynamo=True,  # or False
)

# Run the exported model with ONNX Runtime on the same inputs.
ort_session = ort.InferenceSession("scaled_dot_product_attention.onnx")
np_inputs = {
    "query": query.numpy(),
    "key": key.numpy(),
    "value": value.numpy(),
    "attn_mask": attn_mask.numpy(),
}
onnx_outputs = ort_session.run(None, np_inputs)[0]

# equal_nan=True lets NaNs at matching positions compare equal, so the check
# only fails when the two backends disagree.
torch.testing.assert_close(output, torch.tensor(onnx_outputs), equal_nan=True)
```
Before this change, the script above fails the assertion because the ORT model outputs NaNs.

1 parent 32f2196 commit ecb7677
1 file changed
+5
-0
lines changed

| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2076 | 2076 | | |
2077 | 2077 | | |
2078 | 2078 | | |
| 2079 | + | |
| 2080 | + | |
| 2081 | + | |
| 2082 | + | |
| 2083 | + | |
2079 | 2084 | | |
2080 | 2085 | | |
2081 | 2086 | | |
| |||
0 commit comments