Commit 155839b
[pallas:triton] Emit a better error message for matmul with non-2D operands
Triton seems to support both 2D and 3D operands now, the latter case being a
batched matmul. We need more changes in the lowering to support 3D, so I will
leave it out of scope here.
Fixes jax-ml#26013.
PiperOrigin-RevId: 7332932991 parent 8906f28 commit 155839b
File tree
2 files changed
+22
-0
lines changed- jax/_src/pallas/triton
- tests/pallas
2 files changed
+22
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2261 | 2261 | | |
2262 | 2262 | | |
2263 | 2263 | | |
| 2264 | + | |
| 2265 | + | |
| 2266 | + | |
2264 | 2267 | | |
2265 | 2268 | | |
2266 | 2269 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
733 | 733 | | |
734 | 734 | | |
735 | 735 | | |
| 736 | + | |
| 737 | + | |
| 738 | + | |
| 739 | + | |
| 740 | + | |
| 741 | + | |
| 742 | + | |
| 743 | + | |
| 744 | + | |
| 745 | + | |
| 746 | + | |
| 747 | + | |
| 748 | + | |
| 749 | + | |
| 750 | + | |
| 751 | + | |
| 752 | + | |
| 753 | + | |
| 754 | + | |
736 | 755 | | |
737 | 756 | | |
738 | 757 | | |
| |||
0 commit comments