Commit 77a7b4c
[Blackwell] Enable MMA pipelining for scaled dot when TMEM copy is used (triton-lang#5812)
This PR enables MMA pipelining for scaled dot.
The main difficulty this PR overcomes is the dependency cycle between
TMEM copy rewriting and SWP - currently TMEM copy rewriting relies on
SWP to put loading of scales into SMEM, while to apply MMA pipelining
during SWP, TMEM copy rewriting needs to have happened beforehand. I
propose to break the cycle by having loading of scales go through
`local_alloc` and `local_load` in `AccelerateMatmul`. This way, TMEM
copy rewriting happens during [the first call to
OptimizedDotOperands,](https://github.com/triton-lang/triton/blob/1e0e51c4aeb3e1beea000da5d0e494f8b9ac40dd/third_party/nvidia/backend/compiler.py#L260)
before SWP. And the local alloc and load added in `AccelerateMatmul` are
eliminated during SWP. It's a bit ad hoc to add local alloc for scales
there, since scales do not need to be in SMEM. But other solutions, like
decoupling MMA pipelining from SWP, are more difficult.
The other changes in this PR make SWP recognize the loading of scales
when there is a TMEM copy between the scale load and the MMA.
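To illustrate the ordering problem described above, here is a toy Python model of the pass sequence. It is only a sketch: `ToyKernel`, `compile_toy`, and the three pass functions are hypothetical stand-ins that track boolean flags, not Triton APIs, and the model assumes a later `optimize_dot_operands` call runs after SWP. It also does not model the removal of the temporary local alloc and load during SWP.

```python
from dataclasses import dataclass


@dataclass
class ToyKernel:
    """Hypothetical stand-in for the IR of one scaled-dot loop; not a Triton type."""
    scales_in_smem: bool = False   # scale load goes through local_alloc / local_load
    uses_tmem_copy: bool = False   # TMEM copy rewriting has been applied
    mma_pipelined: bool = False    # MMA pipelining has been applied


def accelerate_matmul(k: ToyKernel, stage_scales: bool) -> ToyKernel:
    # Approach proposed in the PR: route scales through SMEM up front.
    if stage_scales:
        k.scales_in_smem = True
    return k


def optimize_dot_operands(k: ToyKernel) -> ToyKernel:
    # TMEM copy rewriting only fires once the scales are in SMEM.
    if k.scales_in_smem:
        k.uses_tmem_copy = True
    return k


def software_pipeline(k: ToyKernel) -> ToyKernel:
    # MMA pipelining needs TMEM copy rewriting to have happened beforehand.
    if k.uses_tmem_copy:
        k.mma_pipelined = True
    # In the old flow, SWP is also what puts the scale loads into SMEM.
    k.scales_in_smem = True
    return k


def compile_toy(stage_scales: bool) -> ToyKernel:
    k = ToyKernel()
    k = accelerate_matmul(k, stage_scales)
    k = optimize_dot_operands(k)   # first call, before SWP
    k = software_pipeline(k)
    k = optimize_dot_operands(k)   # assumed later call, after SWP
    return k


old_flow = compile_toy(stage_scales=False)
new_flow = compile_toy(stage_scales=True)
print("old flow:", old_flow)
print("new flow:", new_flow)
```

In this model, the old flow ends with the TMEM copy in place but the MMA not pipelined, because the rewrite only becomes possible after SWP has already run. Staging the scales in `accelerate_matmul` lets the first `optimize_dot_operands` call do the rewrite, so SWP can pipeline the MMA.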
@ThomasRaoux @pawelszczerbuk @csullivan @mbrookhart @binarybana
---------
Co-authored-by: Masahiro Masuda <[email protected]>
Co-authored-by: Jason Knight <[email protected]>
1 parent 5d2a1d2 commit 77a7b4c
File tree
8 files changed, +386 -36 lines changed
- lib/Dialect/TritonGPU/Transforms
  - Pipeliner
- python/test/unit/language
- test/TritonGPU