Commit b801539
[Pallas][Mosaic GPU] Add support for compressing squeezed dims in async_copy + grid fixes
This change removes the need to flatten the batch dimension into sequence dimensions
in the flash attention kernel. The critical thing here is the observation that we can
in fact collapse all squeezed dimension into a single one in the TMA descriptor, letting
us reduce its rank when necessary.
Doing this also uncovered some issues with how we were handling the grid in Pallas:MGPU
lowering, which I've fixed.
PiperOrigin-RevId: 7010352771 parent d5bfafb commit b801539
File tree
4 files changed
+110
-44
lines changed- jax
- _src/pallas/mosaic_gpu
- experimental
- mosaic/gpu
- pallas/ops/gpu
- tests/mosaic
4 files changed
+110
-44
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
360 | 360 | | |
361 | 361 | | |
362 | 362 | | |
363 | | - | |
364 | | - | |
365 | | - | |
366 | | - | |
367 | 363 | | |
368 | 364 | | |
369 | 365 | | |
| |||
397 | 393 | | |
398 | 394 | | |
399 | 395 | | |
400 | | - | |
401 | | - | |
402 | 396 | | |
403 | 397 | | |
404 | | - | |
405 | | - | |
406 | | - | |
407 | | - | |
408 | | - | |
| 398 | + | |
409 | 399 | | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
410 | 409 | | |
411 | 410 | | |
412 | 411 | | |
| |||
500 | 499 | | |
501 | 500 | | |
502 | 501 | | |
503 | | - | |
| 502 | + | |
504 | 503 | | |
505 | 504 | | |
506 | 505 | | |
| |||
788 | 787 | | |
789 | 788 | | |
790 | 789 | | |
791 | | - | |
| 790 | + | |
792 | 791 | | |
793 | 792 | | |
794 | 793 | | |
| |||
806 | 805 | | |
807 | 806 | | |
808 | 807 | | |
809 | | - | |
| 808 | + | |
| 809 | + | |
| 810 | + | |
810 | 811 | | |
811 | 812 | | |
812 | 813 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
234 | 234 | | |
235 | 235 | | |
236 | 236 | | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
237 | 288 | | |
238 | 289 | | |
239 | 290 | | |
| |||
397 | 448 | | |
398 | 449 | | |
399 | 450 | | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| 458 | + | |
| 459 | + | |
| 460 | + | |
| 461 | + | |
400 | 462 | | |
401 | 463 | | |
402 | 464 | | |
| |||
421 | 483 | | |
422 | 484 | | |
423 | 485 | | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
| 490 | + | |
| 491 | + | |
| 492 | + | |
| 493 | + | |
| 494 | + | |
| 495 | + | |
| 496 | + | |
| 497 | + | |
| 498 | + | |
| 499 | + | |
| 500 | + | |
| 501 | + | |
424 | 502 | | |
425 | 503 | | |
426 | | - | |
| 504 | + | |
427 | 505 | | |
428 | 506 | | |
429 | 507 | | |
| |||
437 | 515 | | |
438 | 516 | | |
439 | 517 | | |
440 | | - | |
| 518 | + | |
441 | 519 | | |
442 | 520 | | |
443 | 521 | | |
| |||
446 | 524 | | |
447 | 525 | | |
448 | 526 | | |
449 | | - | |
| 527 | + | |
450 | 528 | | |
451 | 529 | | |
452 | 530 | | |
453 | 531 | | |
454 | 532 | | |
455 | 533 | | |
456 | | - | |
| 534 | + | |
457 | 535 | | |
458 | 536 | | |
459 | 537 | | |
| |||
508 | 586 | | |
509 | 587 | | |
510 | 588 | | |
511 | | - | |
512 | | - | |
513 | | - | |
514 | 589 | | |
515 | 590 | | |
516 | 591 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
61 | 61 | | |
62 | 62 | | |
63 | 63 | | |
64 | | - | |
65 | | - | |
66 | | - | |
67 | | - | |
68 | | - | |
69 | | - | |
70 | | - | |
71 | 64 | | |
72 | 65 | | |
73 | 66 | | |
74 | 67 | | |
75 | 68 | | |
76 | | - | |
77 | | - | |
78 | | - | |
79 | 69 | | |
80 | 70 | | |
81 | | - | |
82 | | - | |
| 71 | + | |
83 | 72 | | |
84 | 73 | | |
85 | 74 | | |
| |||
93 | 82 | | |
94 | 83 | | |
95 | 84 | | |
96 | | - | |
| 85 | + | |
97 | 86 | | |
98 | 87 | | |
99 | 88 | | |
100 | | - | |
| 89 | + | |
101 | 90 | | |
102 | 91 | | |
103 | 92 | | |
| |||
167 | 156 | | |
168 | 157 | | |
169 | 158 | | |
170 | | - | |
| 159 | + | |
171 | 160 | | |
172 | 161 | | |
173 | 162 | | |
174 | 163 | | |
175 | 164 | | |
176 | 165 | | |
177 | 166 | | |
178 | | - | |
179 | | - | |
| 167 | + | |
180 | 168 | | |
181 | 169 | | |
182 | 170 | | |
183 | 171 | | |
184 | 172 | | |
185 | 173 | | |
186 | | - | |
187 | | - | |
| 174 | + | |
188 | 175 | | |
189 | 176 | | |
190 | 177 | | |
| |||
199 | 186 | | |
200 | 187 | | |
201 | 188 | | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
202 | 192 | | |
203 | | - | |
| 193 | + | |
204 | 194 | | |
205 | | - | |
| 195 | + | |
206 | 196 | | |
207 | 197 | | |
208 | 198 | | |
| |||
236 | 226 | | |
237 | 227 | | |
238 | 228 | | |
239 | | - | |
| 229 | + | |
240 | 230 | | |
241 | 231 | | |
242 | 232 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1240 | 1240 | | |
1241 | 1241 | | |
1242 | 1242 | | |
1243 | | - | |
| 1243 | + | |
1244 | 1244 | | |
1245 | 1245 | | |
1246 | 1246 | | |
| |||
0 commit comments