Commit a5feacb
[SDPA] [MPS] Fixes regression in 2.8.0 for scaled_dot_product_attention using mps (pytorch#164364)
[SDPA] [MPS] Fixes regression in 2.8.0 for scaled_dot_product_attention using mps (pytorch#163598)
Fixes pytorch#163597
- Updates fast SDPA implementations to take in query tensor stride info similar to key and value instead of assuming stride.
- Updated tests with additional transpose/permutation layouts. New tests catch the regression.
### Benchmarking with script found in [implementation PR](pytorch#152781)
Times are averaged over 100000 iterations. This change should not have any significant performance difference. Tested on an M3 Pro
### Vector Fast Path (q_len=1, k_len=256)
- Before: 0.160 ms
- After: 0.157 ms
### Vector 2-pass (q_len=1, k_len=4096)
- Before: 0.342 ms
- After: 0.339 ms
### Vector Fast Path (q_len=8, k_len=256)
- Before: 0.228 ms
- After: 0.231 ms
### Vector 2-pass (q_len=8, k_len=4096)
- Before: 0.432 ms
- After: 0.436 ms
Pull Request resolved: pytorch#163598
Approved by: https://github.com/malfet
(cherry picked from commit 1c12d74)
Co-authored-by: Vismai Khanderao <[email protected]>1 parent 71282c8 commit a5feacb
File tree
3 files changed
+82
-55
lines changed- aten/src/ATen/native/mps
- kernels
- operations
- test
3 files changed
+82
-55
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
17 | | - | |
18 | | - | |
| 17 | + | |
| 18 | + | |
19 | 19 | | |
20 | 20 | | |
21 | 21 | | |
| |||
28 | 28 | | |
29 | 29 | | |
30 | 30 | | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
35 | 37 | | |
36 | 38 | | |
37 | 39 | | |
| |||
54 | 56 | | |
55 | 57 | | |
56 | 58 | | |
57 | | - | |
58 | 59 | | |
59 | | - | |
| 60 | + | |
| 61 | + | |
60 | 62 | | |
61 | 63 | | |
62 | 64 | | |
| |||
156 | 158 | | |
157 | 159 | | |
158 | 160 | | |
159 | | - | |
160 | | - | |
| 161 | + | |
| 162 | + | |
161 | 163 | | |
162 | 164 | | |
163 | 165 | | |
| |||
170 | 172 | | |
171 | 173 | | |
172 | 174 | | |
173 | | - | |
174 | | - | |
175 | | - | |
176 | | - | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
177 | 181 | | |
178 | 182 | | |
179 | 183 | | |
| |||
196 | 200 | | |
197 | 201 | | |
198 | 202 | | |
199 | | - | |
200 | 203 | | |
201 | 204 | | |
202 | | - | |
| 205 | + | |
| 206 | + | |
203 | 207 | | |
204 | 208 | | |
205 | 209 | | |
| |||
520 | 524 | | |
521 | 525 | | |
522 | 526 | | |
523 | | - | |
524 | | - | |
525 | | - | |
526 | | - | |
527 | | - | |
528 | | - | |
529 | | - | |
530 | | - | |
531 | | - | |
532 | | - | |
533 | | - | |
534 | | - | |
535 | | - | |
536 | | - | |
537 | | - | |
538 | | - | |
539 | | - | |
540 | | - | |
541 | | - | |
| 527 | + | |
| 528 | + | |
| 529 | + | |
| 530 | + | |
| 531 | + | |
| 532 | + | |
| 533 | + | |
| 534 | + | |
| 535 | + | |
| 536 | + | |
| 537 | + | |
| 538 | + | |
| 539 | + | |
| 540 | + | |
| 541 | + | |
| 542 | + | |
| 543 | + | |
| 544 | + | |
| 545 | + | |
542 | 546 | | |
543 | 547 | | |
544 | 548 | | |
| |||
553 | 557 | | |
554 | 558 | | |
555 | 559 | | |
556 | | - | |
557 | | - | |
| 560 | + | |
| 561 | + | |
558 | 562 | | |
559 | 563 | | |
560 | 564 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
182 | 182 | | |
183 | 183 | | |
184 | 184 | | |
| 185 | + | |
| 186 | + | |
185 | 187 | | |
186 | 188 | | |
187 | 189 | | |
| |||
209 | 211 | | |
210 | 212 | | |
211 | 213 | | |
212 | | - | |
213 | | - | |
| 214 | + | |
| 215 | + | |
214 | 216 | | |
215 | 217 | | |
216 | 218 | | |
| |||
257 | 259 | | |
258 | 260 | | |
259 | 261 | | |
| 262 | + | |
| 263 | + | |
260 | 264 | | |
261 | 265 | | |
262 | 266 | | |
| |||
294 | 298 | | |
295 | 299 | | |
296 | 300 | | |
297 | | - | |
298 | | - | |
| 301 | + | |
| 302 | + | |
299 | 303 | | |
300 | 304 | | |
301 | 305 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
9472 | 9472 | | |
9473 | 9473 | | |
9474 | 9474 | | |
9475 | | - | |
9476 | | - | |
| 9475 | + | |
| 9476 | + | |
9477 | 9477 | | |
9478 | 9478 | | |
9479 | | - | |
| 9479 | + | |
| 9480 | + | |
9480 | 9481 | | |
9481 | 9482 | | |
| 9483 | + | |
| 9484 | + | |
| 9485 | + | |
| 9486 | + | |
| 9487 | + | |
| 9488 | + | |
| 9489 | + | |
| 9490 | + | |
| 9491 | + | |
| 9492 | + | |
| 9493 | + | |
9482 | 9494 | | |
9483 | 9495 | | |
9484 | 9496 | | |
9485 | | - | |
| 9497 | + | |
| 9498 | + | |
| 9499 | + | |
| 9500 | + | |
| 9501 | + | |
| 9502 | + | |
| 9503 | + | |
| 9504 | + | |
| 9505 | + | |
9486 | 9506 | | |
9487 | 9507 | | |
9488 | 9508 | | |
| |||
9523 | 9543 | | |
9524 | 9544 | | |
9525 | 9545 | | |
9526 | | - | |
| 9546 | + | |
9527 | 9547 | | |
9528 | 9548 | | |
9529 | | - | |
| 9549 | + | |
9530 | 9550 | | |
9531 | 9551 | | |
9532 | 9552 | | |
9533 | 9553 | | |
9534 | 9554 | | |
9535 | | - | |
| 9555 | + | |
9536 | 9556 | | |
9537 | 9557 | | |
9538 | 9558 | | |
9539 | | - | |
| 9559 | + | |
9540 | 9560 | | |
9541 | | - | |
| 9561 | + | |
9542 | 9562 | | |
9543 | 9563 | | |
9544 | 9564 | | |
9545 | 9565 | | |
9546 | 9566 | | |
9547 | 9567 | | |
9548 | | - | |
| 9568 | + | |
9549 | 9569 | | |
9550 | 9570 | | |
9551 | 9571 | | |
9552 | 9572 | | |
9553 | | - | |
| 9573 | + | |
9554 | 9574 | | |
9555 | 9575 | | |
9556 | | - | |
| 9576 | + | |
9557 | 9577 | | |
9558 | 9578 | | |
9559 | 9579 | | |
9560 | 9580 | | |
9561 | 9581 | | |
9562 | | - | |
| 9582 | + | |
9563 | 9583 | | |
9564 | 9584 | | |
9565 | 9585 | | |
9566 | 9586 | | |
9567 | | - | |
9568 | 9587 | | |
9569 | 9588 | | |
9570 | 9589 | | |
| |||
0 commit comments