Improve GEMM perf when one matrix is transposed (#2347)
The 2D block load/store does not work when one of the input matrices to
a `tt.dot` is transposed inside the Triton kernel using the `stride`
parameter. In the user example, the block pointer is transposed using
stride but the `order` parameter is left unchanged. This results in
`materialize-block-pointer` being unable to detect that a `block_io`
attribute `column-major` should be added to the matrix. Even if this
attribute were added, `rewrite-tensor-pointer` would remove the block
pointer because column major was not supported.
This PR adds support for detecting `column-major` based on `stride`
instead of `order` and also brings the same logic to
`rewrite-tensor-pointer` so that the column-major load is preserved
and eventually lowered to a 2D block load. With this, transposed
matrix performance is more in line with the non-transposed
version:
```
Compute A x B
(I): Detected 7680 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
✅ Triton and Torch match
Time for torch: 0.31821921467781067 ms
Time for triton: 0.4404735863208771 ms
Compute A x B.T
(I): Detected 7680 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
✅ Triton and Torch match
Time for torch: 0.33270877599716187 ms
Time for triton: 0.6352895498275757 ms
```
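To illustrate why `stride` is a more reliable signal than `order` here, the following is a minimal sketch in plain Python. It is not the code of either pass: the helper names `layout_from_strides` and `layout_from_order` are hypothetical, and it assumes the usual block pointer convention that `order=(1, 0)` means dimension 1 varies fastest.

```python
def layout_from_strides(strides):
    """Classify a 2D block by its element strides (hypothetical helper).

    The dimension whose stride is 1 is the contiguous one.
    """
    row_stride, col_stride = strides
    if col_stride == 1:
        return "row-major"
    if row_stride == 1:
        return "column-major"
    return "unknown"


def layout_from_order(order):
    """Classify a 2D block by its `order` (hypothetical helper).

    Assumes order lists dimensions from fastest- to slowest-varying.
    """
    return "row-major" if tuple(order) == (1, 0) else "column-major"


# B is stored as an N x K row-major matrix, so its strides are (K, 1).
# The kernel views it as K x N by swapping the strides to (1, K),
# while leaving `order` at its row-major default.
K = 64
transposed_strides = (1, K)
unchanged_order = (1, 0)

print(layout_from_order(unchanged_order))       # row-major: misses the transpose
print(layout_from_strides(transposed_strides))  # column-major: reflects the real layout
```

With the strides swapped and `order` untouched, only the stride-based check reports the layout that the memory accesses actually follow.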
Close #1795