Commit 8af9f58
Fix Local Attention off by 1 bug (microsoft#25927)
### Description
Previously, local window size of GQA op excluded the current token. This
does not match standard HuggingFace implementations where tokens are
appended and then local masking occurs; the mismatch can cause the mask
to be off by 1 during generation, leading to accuracy issues. This PR
corrects this mismatch by including the current token. In practice, this
effectively decreases GQA window size by 1.
### Motivation and Context
This helps align our models with HuggingFace models.
---------
Co-authored-by: Kunal Vaishnavi <[email protected]>1 parent 978bfca commit 8af9f58
File tree
9 files changed
+13
-16
lines changed- onnxruntime
- contrib_ops
- cpu/bert
- cuda/bert
- cutlass_fmha
- webgpu/bert
- test/python/transformers
9 files changed
+13
-16
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
86 | 86 | | |
87 | 87 | | |
88 | 88 | | |
89 | | - | |
| 89 | + | |
90 | 90 | | |
91 | 91 | | |
92 | 92 | | |
| |||
106 | 106 | | |
107 | 107 | | |
108 | 108 | | |
109 | | - | |
| 109 | + | |
110 | 110 | | |
111 | 111 | | |
112 | 112 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
297 | 297 | | |
298 | 298 | | |
299 | 299 | | |
300 | | - | |
301 | 300 | | |
302 | | - | |
| 301 | + | |
303 | 302 | | |
304 | | - | |
305 | | - | |
| 303 | + | |
| 304 | + | |
306 | 305 | | |
307 | 306 | | |
308 | 307 | | |
309 | | - | |
| 308 | + | |
310 | 309 | | |
311 | 310 | | |
312 | 311 | | |
| |||
Lines changed: 1 addition & 3 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
223 | 223 | | |
224 | 224 | | |
225 | 225 | | |
226 | | - | |
227 | | - | |
228 | | - | |
| 226 | + | |
229 | 227 | | |
230 | 228 | | |
231 | 229 | | |
| |||
Lines changed: 1 addition & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
476 | 476 | | |
477 | 477 | | |
478 | 478 | | |
479 | | - | |
| 479 | + | |
480 | 480 | | |
481 | 481 | | |
482 | 482 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
326 | 326 | | |
327 | 327 | | |
328 | 328 | | |
329 | | - | |
| 329 | + | |
330 | 330 | | |
331 | 331 | | |
332 | 332 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
250 | 250 | | |
251 | 251 | | |
252 | 252 | | |
253 | | - | |
| 253 | + | |
254 | 254 | | |
255 | 255 | | |
256 | 256 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
582 | 582 | | |
583 | 583 | | |
584 | 584 | | |
585 | | - | |
| 585 | + | |
586 | 586 | | |
587 | 587 | | |
588 | 588 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1122 | 1122 | | |
1123 | 1123 | | |
1124 | 1124 | | |
1125 | | - | |
| 1125 | + | |
1126 | 1126 | | |
1127 | 1127 | | |
1128 | 1128 | | |
| |||
Lines changed: 1 addition & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
331 | 331 | | |
332 | 332 | | |
333 | 333 | | |
334 | | - | |
| 334 | + | |
335 | 335 | | |
336 | 336 | | |
337 | 337 | | |
| |||
0 commit comments