[BACKEND] Don't promote fp8 MMAv2 dot inputs for sm120 (#7409)
Fixes #7188
This speeds up fp8 matmuls on consumer Blackwell (RTX 50xx series) by
~1.9x at large matrix sizes.
sm >= 89 supports MMAv2 with fp8 operands, but prior to this PR Triton
only used this on sm == 89; on other architectures, fp8 inputs were
promoted to fp16 and the MMA was executed in fp16.
This PR skips the fp8->fp16 promotion step on any architecture >= 89.
It also adds more MMA variants to support fp8 operands with fp16
results, which were previously supported only via the
`FP16_FP16_FP16_FP16` variant.
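For context, here is a minimal sketch (not code from this PR) of the kind of Triton kernel affected: a `tl.dot` on fp8 (e4m3) operands that, with this change, lowers to fp8 MMAv2 on sm >= 89 instead of being promoted to fp16 first. Block sizes, strides, and the absence of masking are simplifications; it assumes M, N, K are divisible by the block sizes and that `a_ptr`/`b_ptr` point to fp8 tensors.
```python
import triton
import triton.language as tl

@triton.jit
def fp8_matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                      stride_am, stride_ak,
                      stride_bk, stride_bn,
                      stride_cm, stride_cn,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        # a and b are fp8 tiles; before this PR they were cast to fp16 here
        # on sm > 89, after it tl.dot keeps them as fp8 MMAv2 operands.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc = tl.dot(a, b, acc)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.float16))
```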
Evidence that fp8 operands to MMAv2 should work on any
architecture >= 89: in the PTX docs
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma,
under the "Target ISA Notes" section, e4m3 and e5m2 are listed as
supported on sm_89 or higher, and they don't require the "a" suffix,
which would indicate that the support is non-backward-compatible.
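As a reference point (my own mapping, not something stated in the PR), the PTX e4m3/e5m2 operand types correspond to the following Triton element types, which is what the promotion decision sees on the `tl.dot` operands:
```python
import triton.language as tl

# Assumed mapping between the PTX fp8 operand types referenced above and
# Triton's element types: tl.float8e4nv is the NVIDIA e4m3 variant and
# tl.float8e5 is e5m2.
PTX_TO_TRITON_FP8 = {
    "e4m3": tl.float8e4nv,
    "e5m2": tl.float8e5,
}
```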
Perf improvement verified on a 5070 Ti using 03-matrix-multiplication.py
(below are TFLOPS measurements at large M/N/K sizes):
Before:
```
matmul-performance-fp8:
M N K Triton
...
26 3584.0 3584.0 3584.0 101.256071
27 3712.0 3712.0 3712.0 99.947313
28 3840.0 3840.0 3840.0 101.182062
29 3968.0 3968.0 3968.0 101.771419
30 4096.0 4096.0 4096.0 101.206889
```
After:
```
matmul-performance-fp8:
M N K Triton
...
26 3584.0 3584.0 3584.0 191.309345
27 3712.0 3712.0 3712.0 190.280662
28 3840.0 3840.0 3840.0 195.316740
29 3968.0 3968.0 3968.0 194.305628
30 4096.0 4096.0 4096.0 193.258070
```
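For reference, a sketch of how fp8 inputs like the benchmark's can be built (the numbers above come from Triton's 03-matrix-multiplication.py tutorial; the shapes and casts here are illustrative, not copied from it):
```python
import torch

# Illustrative only: build fp8 (e4m3) operands at the largest size measured
# above. Feeding these to the tutorial's Triton matmul exercises the tl.dot
# path that, with this PR, lowers to fp8 MMAv2 on sm >= 89 instead of fp16.
M = N = K = 4096
a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
```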