[AMD] Add tilesPerWarp parameter to mfma layout (#7283)
This PR introduces the tilesPerWarp parameter to the MFMA layout.
Previously, the MFMA layout assumed that each warp within a CTA tile
computed a single MFMA tile.
When the tensor was larger than a single CTA tile, these tiles were
repeated across the tensor.
In this setup, the output tiles computed by each warp were strided by
the number of warps per CTA along that dimension, in both the row and
column dimensions.
For instance, with 16 MFMA tiles (a 4x4 grid) and warpsPerCTA = [2, 2],
the distribution of warps across the MFMA tiles looked like:
w0 w1 w0 w1
w2 w3 w2 w3
w0 w1 w0 w1
w2 w3 w2 w3
The new tilesPerWarp parameter allows each warp to compute contiguous
MFMA tiles in the row and/or column dimensions. Using the same example
with tilesPerWarp = [2, 2], the layout becomes:
w0 w0 w1 w1
w0 w0 w1 w1
w2 w2 w3 w3
w2 w2 w3 w3
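The two diagrams above can be reproduced with a small standalone sketch. It is not the Triton implementation: the helper names `warp_grid` and `render` are hypothetical, and the row-major warp numbering (w0, w1 along a row, then w2, w3) is an assumption read off the diagrams rather than taken from the layout code.

```python
def warp_grid(tile_rows, tile_cols, warps_per_cta, tiles_per_warp=(1, 1)):
    """Return, for each MFMA tile, the id of the warp that computes it.

    Illustrative model only: warps are numbered row-major within the CTA,
    matching the diagrams in this message (an assumption, not the real code).
    """
    grid = []
    for r in range(tile_rows):
        row = []
        for c in range(tile_cols):
            # Each warp owns tiles_per_warp contiguous tiles per dimension;
            # beyond that, ownership wraps around the warps in the CTA.
            warp_r = (r // tiles_per_warp[0]) % warps_per_cta[0]
            warp_c = (c // tiles_per_warp[1]) % warps_per_cta[1]
            row.append(warp_r * warps_per_cta[1] + warp_c)
        grid.append(row)
    return grid


def render(grid):
    """Format a grid of warp ids in the same style as the diagrams."""
    return "\n".join(" ".join(f"w{w}" for w in row) for row in grid)


# Previous behaviour: one tile per warp, repeated across the tensor.
print(render(warp_grid(4, 4, (2, 2))))
print()
# With tilesPerWarp = [2, 2]: each warp owns a contiguous 2x2 block of tiles.
print(render(warp_grid(4, 4, (2, 2), (2, 2))))
```

With tiles_per_warp left at (1, 1) the sketch reduces to the old strided distribution, which is why the parameter can be read as a generalization of the previous layout.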
While this is a general enhancement, the main motivation for introducing
this parameter is to improve memory access efficiency for scale tensors
in scaled dot operations.
Specific patterns and use cases will be implemented in follow-up PRs.
---------
Co-authored-by: Ognjen Plavsic <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>