Commit a21bbbc
[MXFP] mxfp conversions speedup (#8610)
This PR improves the throughput of the mxfp8 and mxfp4 upcast and downcast operations. It includes a commit from @jongsoo-openai (original PR
[here](triton-lang/triton#8179)), with the improvements below added on top. The PR is functionally a no-op, which
is verified by the tests in
``python/triton_kernels/tests/test_mxfp.py``.
Upcast improvements:
- Added native packed e2m1 conversion to fp16 (for Blackwell+).
- Added tensor descriptors to utilize TMA for reading the input mxfp
value tensor and writing the output.
- Note that this required padding the innermost dimension of I/O tensors
that do not meet the tensor descriptor alignment requirements, and
unpadding the output afterwards.
- Tuned tile dimensions and num_warps.
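For context, e2m1 encodes 1 sign bit, 2 exponent bits (bias 1), and 1 mantissa bit, so each 4-bit code maps to one of eight magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}. The sketch below is a minimal reference for the packed-upcast semantics, not the kernel itself; the low-nibble-first ordering is an assumption, and the native Blackwell path does this conversion in hardware:

```python
# Reference semantics of a packed e2m1 (mxfp4) -> float upcast.
# Each byte holds two 4-bit e2m1 values; low nibble first is assumed here.
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def e2m1_to_float(nibble: int) -> float:
    """Decode one 4-bit e2m1 code: bit 3 is the sign, bits 0-2 index a magnitude."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_LUT[nibble & 0x7]

def unpack_byte(b: int) -> tuple[float, float]:
    """Decode the two e2m1 values packed in one byte."""
    return e2m1_to_float(b & 0xF), e2m1_to_float(b >> 4)
```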
Downcast improvements:
- Enabled vectorized stores of mxfp4 value tensors (h/t @ThomasRaoux)
instead of byte-level stores.
- Tuned the tile dimensions as well as num_warps.
- Unfortunately, unlike for upcast, tensor descriptors did not give a
consistent performance improvement, so they are not used here.
I left further performance tuning as a TODO for a subsequent PR.
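The downcast store path above can be illustrated with a small sketch: quantize each float to the nearest e2m1 code, then pack two codes per byte so that consecutive bytes can be coalesced into wide stores. This is a hypothetical reference in plain Python, not the kernel's code; round-to-nearest with ties toward the smaller magnitude and low-nibble-first packing are assumptions:

```python
# Sketch of an mxfp4 downcast: quantize to the nearest e2m1 code and pack
# two codes per byte (the packed bytes are what the kernel stores with
# vectorized, rather than byte-level, stores).
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def float_to_e2m1(x: float) -> int:
    """Encode a float as a 4-bit e2m1 code, rounding to the nearest magnitude."""
    sign = 0x8 if x < 0 else 0x0
    mag = abs(x)
    code = min(range(8), key=lambda i: abs(E2M1_VALUES[i] - mag))
    return sign | code

def pack_e2m1(values: list[float]) -> bytes:
    """Pack an even-length list of floats into e2m1 nibble pairs, low nibble first."""
    assert len(values) % 2 == 0
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        out.append(float_to_e2m1(lo) | (float_to_e2m1(hi) << 4))
    return bytes(out)
```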
### Performance comparison (BW, in GB/s)
Measured via ``python/triton_kernels/tests/test_mxfp.py``.
**Before -- GB200**
```
MXFP8 (e4m3fn):
M N quant_dtype quant_bw_bfloat16 quant_bw_float16 dequant_bw_bfloat16 dequant_bw_float16
---- ---- ------------------- ------------------- ------------------ --------------------- --------------------
1024 8192 torch.float8_e4m3fn 1985.94 2053.35 2154.61 2347.56
4096 8192 torch.float8_e4m3fn 3479.79 3518.71 3243.02 3753.85
MXFP4 (e2m1):
M N quant_dtype quant_bw_bfloat16 quant_bw_float16 dequant_bw_bfloat16 dequant_bw_float16
---- ---- ------------- ------------------- ------------------ --------------------- --------------------
1024 8192 torch.uint8 808.089 815.124 647.589 713.9
4096 8192 torch.uint8 1045.23 1041.91 811.089 888.624
```
**After -- GB200**
```
MXFP8 (e4m3fn):
M N quant_dtype quant_bw_bfloat16 quant_bw_float16 dequant_bw_bfloat16 dequant_bw_float16
---- ---- ------------------- ------------------- ------------------ --------------------- --------------------
1024 8192 torch.float8_e4m3fn 2259.86 2404.99 2119.76 2361.66
4096 8192 torch.float8_e4m3fn 4106.69 4268.29 4038.16 4059
MXFP4 (e2m1):
M N quant_dtype quant_bw_bfloat16 quant_bw_float16 dequant_bw_bfloat16 dequant_bw_float16
---- ---- ------------- ------------------- ------------------ --------------------- --------------------
1024 8192 torch.uint8 1334.75 1332.03 1424.7 1397.36
4096 8192 torch.uint8 2027.41 2028.98 2097.15 2275.56
```
**Before -- H100**
```
MXFP8 (e4m3fn):
M N quant_dtype quant_bw_bfloat16 quant_bw_float16 dequant_bw_bfloat16 dequant_bw_float16
---- ---- ------------------- ------------------- ------------------ --------------------- --------------------
1024 8192 torch.float8_e4m3fn 1250.29 1244.35 1595.2 1588.75
4096 8192 torch.float8_e4m3fn 1805.81 1799.62 2080.51 2118.34
MXFP4 (e2m1):
M N quant_dtype quant_bw_bfloat16 quant_bw_float16 dequant_bw_bfloat16 dequant_bw_float16
---- ---- ------------- ------------------- ------------------ --------------------- --------------------
1024 8192 torch.uint8 418.493 416.102 572.367 627.739
4096 8192 torch.uint8 489.531 490.08 687.861 758.08
```
**After -- H100**
```
MXFP8 (e4m3fn):
M N quant_dtype quant_bw_bfloat16 quant_bw_float16 dequant_bw_bfloat16 dequant_bw_float16
---- ---- ------------------- ------------------- ------------------ --------------------- --------------------
1024 8192 torch.float8_e4m3fn 1604.96 1624.86 1732.23 1751.52
4096 8192 torch.float8_e4m3fn 2347.56 2337.09 2386.74 2292.8
MXFP4 (e2m1):
M N quant_dtype quant_bw_bfloat16 quant_bw_float16 dequant_bw_bfloat16 dequant_bw_float16
---- ---- ------------- ------------------- ------------------ --------------------- --------------------
1024 8192 torch.uint8 731.429 745.575 892.861 917.871
4096 8192 torch.uint8 882.343 894.995 1102.37 1165.08
```
Co-authored-by: jongsoo-openai <[email protected]>
**File tree** -- 4 files changed (+218, -80 lines), under ``python/triton_kernels``:
- tests
- triton_kernels/numerics_details
  - mxfp_details