@@ -54,18 +54,24 @@ def scale_dim0_dim1_reference(
     return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1


-def to_mx_dim0_reference(x_hp, block_size, scaling_mode=ScaleCalculationMode.FLOOR):
-    scale_d0, data_d0 = to_mx(
-        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
-    )
+def to_mx_dim0_reference(
+    x_hp,
+    block_size,
+    scaling_mode=ScaleCalculationMode.FLOOR,
+    target_dtype=torch.float8_e4m3fn,
+):
+    scale_d0, data_d0 = to_mx(x_hp, target_dtype, block_size, scaling_mode=scaling_mode)
     return data_d0, scale_d0


-def to_mx_dim1_reference(x_hp, block_size, scaling_mode=ScaleCalculationMode.FLOOR):
+def to_mx_dim1_reference(
+    x_hp,
+    block_size,
+    scaling_mode=ScaleCalculationMode.FLOOR,
+    target_dtype=torch.float8_e4m3fn,
+):
     x_hp = x_hp.t().contiguous()
-    scale_d1, data_d1 = to_mx(
-        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
-    )
+    scale_d1, data_d1 = to_mx(x_hp, target_dtype, block_size, scaling_mode=scaling_mode)
     return data_d1.t(), scale_d1

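A quick usage sketch for the updated helpers (illustrative only, not part of the diff; the block size of 32 is the standard MX block size and an assumption here):

    x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
    # default target_dtype keeps the previous hard-coded mxfp8 behavior
    data_fp8, scale_fp8 = to_mx_dim0_reference(x, 32)
    # new: the same helper can target mxfp4 via the packed fp4 dtype
    data_fp4, scale_fp4 = to_mx_dim0_reference(x, 32, target_dtype=torch.float4_e2m1fn_x2)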
@@ -88,13 +94,14 @@ def run(
         "dim0",
         "dim1",
         "dim0_dim1",
-        "dim0_mx_floor",
-        "dim0_mx_rceil",
-        "dim1_mx_floor",
-        "dim1_mx_rceil",
-        "dim1_mx_triton_floor",
-        "dim1_mx_cuda_floor",
-        "dim1_mx_cuda_rceil",
+        "dim0_mxfp8_floor",
+        "dim0_mxfp4_floor",
+        "dim0_mxfp8_rceil",
+        "dim1_mxfp8_floor",
+        "dim1_mxfp8_rceil",
+        "dim1_mxfp8_triton_floor",
+        "dim1_mxfp8_cuda_floor",
+        "dim1_mxfp8_cuda_rceil",
     )

     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
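With the renames above, exercising the new mxfp4 path would look roughly like this (hypothetical direct call of the entry point; the M/K values are arbitrary and run()'s full signature is not shown in this diff):

    run(M=16384, K=16384, mode="dim0_mxfp4_floor")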
@@ -154,7 +161,7 @@ def run(
         )
         bps = bytes_rw / (time_us / 1e6)

-    elif mode == "dim0_mx_floor":
+    elif mode == "dim0_mxfp8_floor":
         to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
         y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE)

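Each compiled branch in this file follows the same measurement recipe; pulled out in isolation it reads (a sketch reusing the file's own names, nothing new):

    # compile the reference op, trigger compilation with one eager call,
    # warm up twice, then time with the CUDA-event-based helper
    to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
    y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE)
    for _ in range(2):
        __ = to_mx_dim0_reference_c(x, BLOCK_SIZE)
    time_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mx_dim0_reference_c(x, BLOCK_SIZE), x, BLOCK_SIZE
    )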
@@ -172,7 +179,32 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

-    elif mode == "dim0_mx_rceil":
+    elif mode == "dim0_mxfp4_floor":
+        to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
+        y_d0, s_d0 = to_mx_dim0_reference_c(
+            x, BLOCK_SIZE, target_dtype=torch.float4_e2m1fn_x2
+        )
+
+        for _ in range(2):
+            __ = to_mx_dim0_reference_c(
+                x, BLOCK_SIZE, target_dtype=torch.float4_e2m1fn_x2
+            )
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim0_reference_c(
+                x, BLOCK_SIZE, target_dtype=torch.float4_e2m1fn_x2
+            ),
+            x,
+            BLOCK_SIZE,
+        )
+
+        # TODO(future PR): make to_mx return float4 directly
+        assert y_d0.dtype == torch.uint8
+        assert s_d0.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim0_mxfp8_rceil":
         to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
         y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)

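Why the mxfp4 branch can reuse bytes_per_el_fp8 for its write traffic (reasoning from the asserts above; the packing detail is an assumption about to_mx's current uint8 output): float4_e2m1fn_x2 packs two 4-bit values per byte, so y_d0 holds half as many uint8 elements as x holds bf16 elements, and one byte per element is the correct multiplier:

    M, K, BLOCK_SIZE = 16384, 16384, 32
    fp4_data_bytes = (M * K) // 2         # two e2m1 values per uint8 element
    scale_bytes = (M * K) // BLOCK_SIZE   # one e8m0 scale byte per block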
@@ -190,7 +222,7 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

-    elif mode == "dim1_mx_floor":
+    elif mode == "dim1_mxfp8_floor":
         to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
         y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)

@@ -208,7 +240,7 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

-    elif mode == "dim1_mx_rceil":
+    elif mode == "dim1_mxfp8_rceil":
         to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
         y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)

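Throughout, bps is achieved memory bandwidth in bytes per second: total bytes read plus written, divided by the measured time converted from microseconds to seconds. A quick worked example:

    bytes_rw = 2**30                  # 1 GiB of total traffic
    time_us = 500.0                   # measured kernel time in microseconds
    bps = bytes_rw / (time_us / 1e6)  # ~2.15e12 bytes/s, i.e. ~2.15 TB/s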
@@ -226,7 +258,7 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

-    elif mode == "dim1_mx_triton_floor":
+    elif mode == "dim1_mxfp8_triton_floor":
         y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)

         for _ in range(2):
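A hedged correctness sketch for the fused Triton path (assumes triton_to_mxfp8_dim1 returns (data, scale) in the same layout as the reference helper; this diff does not assert that):

    y_ref, s_ref = to_mx_dim1_reference(x, BLOCK_SIZE)
    y_tri, s_tri = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
    torch.testing.assert_close(
        y_tri.to(torch.float32), y_ref.to(torch.float32), atol=0, rtol=0
    )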
@@ -243,7 +275,7 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

-    elif mode == "dim1_mx_cuda_floor":
+    elif mode == "dim1_mxfp8_cuda_floor":
         from torchao.prototype import mxfp8_cuda

         _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
@@ -269,7 +301,7 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

-    elif mode == "dim1_mx_cuda_rceil":
+    elif mode == "dim1_mxfp8_cuda_rceil":
         from torchao.prototype import mxfp8_cuda

         _, y_d1, _, s_d1 = mxfp8_cuda.quantize(