Commit 194539e
Address NaNs if SDPA is called with all values masked from query (pytorch#157727)
Fixes pytorch#156707
Detect if all values along the softmax axis are infs and overwrite the outputs for those computations with zeros before the final matmul. The behavior should be aligned with the CPU implementation.
These types of cases where all values along the dimension in the attention mask are false leading to the undefined outputs in softmax occur with left padded batches for generation in HF transformers according to the original issue.
Pull Request resolved: pytorch#157727
Approved by: https://github.com/malfet1 parent bcf5063 commit 194539e
File tree
2 files changed
+27
-1
lines changed- aten/src/ATen/native/mps/operations
- test
2 files changed
+27
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
114 | 114 | | |
115 | 115 | | |
116 | 116 | | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
117 | 126 | | |
118 | | - | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
119 | 133 | | |
120 | 134 | | |
121 | 135 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
9257 | 9257 | | |
9258 | 9258 | | |
9259 | 9259 | | |
| 9260 | + | |
| 9261 | + | |
| 9262 | + | |
| 9263 | + | |
| 9264 | + | |
| 9265 | + | |
| 9266 | + | |
| 9267 | + | |
| 9268 | + | |
| 9269 | + | |
| 9270 | + | |
| 9271 | + | |
9260 | 9272 | | |
9261 | 9273 | | |
9262 | 9274 | | |
| |||
0 commit comments