Commit 1fa520e
[ROCm] Enable group gemm through CK (pytorch#166334)
Fixes pytorch#161366
All four combinations of operand dimensions are supported: 2d-2d, 2d-3d, 3d-3d, and 3d-2d.
The corresponding test cases in test_matmul_cuda pass for both the forward and backward passes.
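To make the four operand layouts concrete, here is a minimal plain-Python reference sketch of grouped GEMM semantics. It is not the CK kernel and not PyTorch's implementation. The function names (`grouped_mm`, `matmul`, `cols`) are hypothetical, and the offsets convention (cumulative end index of each group along the split dimension) is an assumption made for illustration.

```python
# Reference semantics only: plain-Python grouped GEMM, not the CK kernel.
# Assumption: `offs` holds the cumulative end index of each group.

def matmul(a, b):
    """Multiply two dense matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def cols(m, start, end):
    """Return columns [start, end) of a matrix."""
    return [row[start:end] for row in m]

def grouped_mm(a, b, offs=None):
    """Grouped matrix multiply for 2d/3d operands (nested lists)."""
    a3d = isinstance(a[0][0], list)
    b3d = isinstance(b[0][0], list)
    if a3d and b3d:
        # 3d-3d: one independent product per group, 3d output.
        return [matmul(ag, bg) for ag, bg in zip(a, b)]
    if offs is None:
        raise ValueError("offs is required when an operand is 2d")
    # Turn cumulative ends into (start, end) pairs, one per group.
    bounds = list(zip([0] + list(offs[:-1]), offs))
    if not a3d and b3d:
        # 2d-3d: rows of `a` are split into groups; row blocks are stacked.
        out = []
        for (s, e), bg in zip(bounds, b):
            out.extend(matmul(a[s:e], bg))
        return out
    if a3d and not b3d:
        # 3d-2d: columns of `b` are split; column blocks sit side by side.
        blocks = [matmul(ag, cols(b, s, e)) for ag, (s, e) in zip(a, bounds)]
        return [sum((blk[i] for blk in blocks), []) for i in range(len(a[0]))]
    # 2d-2d: the shared inner dimension is split; 3d output, one matrix per group.
    return [matmul(cols(a, s, e), b[s:e]) for s, e in bounds]
```

For example, with `a = [[1, 2], [3, 4], [5, 6]]`, two 2x2 matrices in `b`, and `offs = [1, 3]`, the first row of `a` is multiplied by `b[0]` and the remaining two rows by `b[1]`.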
The CK path is enabled for gfx942 and gfx950.
TODO: Enable support on gfx90a. The CK kernel used in this commit produces a GPU error there,
so gfx90a may require a different CK kernel config, chosen based on profiler results on gfx90a.
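The architecture gating described above can be sketched as a small predicate. This is a hypothetical helper that mirrors the commit message only; it is not the dispatch code added by this commit, and the names `CK_GROUPED_GEMM_ARCHS` and `ck_grouped_gemm_enabled` are invented for illustration.

```python
# Hypothetical helper mirroring the commit message, not PyTorch's dispatch code.
# Architectures listed in the commit as having the CK grouped GEMM path enabled.
CK_GROUPED_GEMM_ARCHS = frozenset({"gfx942", "gfx950"})

def ck_grouped_gemm_enabled(arch_name: str) -> bool:
    """Return True if the arch string names a GPU with the CK path enabled."""
    # ROCm arch strings may carry feature suffixes, e.g. "gfx942:sramecc+:xnack-",
    # so only the part before the first ':' is compared.
    base = arch_name.split(":", 1)[0].strip().lower()
    return base in CK_GROUPED_GEMM_ARCHS
```

On a ROCm build of PyTorch, the arch string would typically come from something like `torch.cuda.get_device_properties(0).gcnArchName`; that lookup is left out here so the sketch runs without a GPU.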
Pull Request resolved: pytorch#166334
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony
Parent commit: c2e3cc7
File tree
4 files changed (+487, -2)
- aten/src/ATen/native
  - cuda
  - hip
- test