- from dataclasses import dataclass, fields
+ # isort: off
+ # fmt: off
+ from dataclasses import dataclass, fields, replace
import pytest
import torch
from typing import Union
from triton_kernels.routing import routing
# matmul utilities
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
- from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs
+ from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue
from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch
from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig
from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
from triton_kernels.tensor_details import layout
# numerics utilities
from triton_kernels.numerics import InFlexData, OutFlexData
- from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
+ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp, dequantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE
# testing utilities
from triton_kernels.testing import assert_close, compute_actual_scale
# target-specific utilities
@@ -78,9 +80,8 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
# ---------------


- def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, device="cuda"):
-     act_use_flexpoint = out_dtype.itemsize == 1
-     weight_use_flexpoint = weight_dtype.itemsize == 1 and not is_mixed_input
+ def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_expts_tot=1, device="cuda"):
+     weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp
    # flexpoint
    make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) +
                                                  ([val0]
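# Sketch: why init_precision's signature changed (inferred from this hunk and the updated
# call sites later in this diff). The old code derived act_use_flexpoint from
# out_dtype.itemsize == 1, but mxfp8 activations are also one byte wide and must not take
# the flexpoint path, so the caller now passes the decision explicitly.
import torch
for name, dtype in [("float8_e5m2", torch.float8_e5m2), ("mxfloat8_e4m3fn", torch.float8_e4m3fn)]:
    act_use_flexpoint = name.startswith("float8")         # what the new call sites pass in
    print(name, dtype.itemsize == 1, act_use_flexpoint)   # itemsize alone cannot tell these apart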
@@ -106,13 +107,14 @@ def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config):

    def apply(x, scale):
        if scale is None:
-             return x.clone().detach().requires_grad_(True)
+             x = x.clone()
        elif scale.numel() == 1:
-             return (x.float() * scale).detach().requires_grad_(True)
+             x = x.float() * scale
        else:
            assert x.ndim == 3
            assert scale.numel() == x.shape[0]
-             return (x.float() * scale[:, None, None]).detach().requires_grad_(True)
+             x = x.float() * scale[:, None, None]
+         return x.detach().requires_grad_()


    return (
        apply(x_tri, flex_ctx.lhs_data.scale),
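# Sketch: sanity check that the refactor above is behavior-preserving.
# Tensor.requires_grad_() defaults to requires_grad=True, so funnelling every branch
# through one x.detach().requires_grad_() matches the old per-branch
# .detach().requires_grad_(True) calls.
import torch
t = torch.zeros(2)
a = t.clone().detach().requires_grad_()        # new form
b = t.clone().detach().requires_grad_(True)    # old form
assert a.requires_grad and b.requires_grad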
@@ -215,6 +217,19 @@ class Case:
    Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
    Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
    Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False),
+     Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
+     Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
+     Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
+     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
+     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
+     Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
+     Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
+     Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
+     Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4),
+     Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
+     Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4),
+     Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
    # AMD
    Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
    Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
@@ -247,8 +262,12 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
        pytest.skip("Float8 not tested on A100")
    if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
        pytest.skip("float16 x mx not supported with cuda capability >= 10")
-     if "float8" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 10:
-         pytest.skip("float8 x mx not supported with cuda capability < 10")
+     if weight_dtype_str.startswith("mx"):
+         if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:
+             pytest.skip("float8 x mx not supported with cuda capability < 10")
+         if act_dtype_str == "mxfloat8_e4m3fn":
+             if is_persistent:
+                 pytest.skip("mx x mx not supported with persistent kernel")
    if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
        pytest.skip("Not enough memory on A100")

@@ -257,6 +276,8 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
        pytest.skip("float8 x mx only supported on CDNA4")
    if "float8" in act_dtype_str and "mxfloat8" in weight_dtype_str:
        pytest.skip("NYI: float8 x mxfloat8 not tested on AMD GPU")
+     if act_dtype_str.startswith("mx") and weight_dtype_str.startswith("mx"):
+         pytest.skip("NYI: mx x mx not tested on AMD GPU")
    if is_persistent:
        pytest.skip("NYI: Persistent kernel not supported on AMD GPU")
    if split_k > 1:
@@ -301,24 +322,30 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
    }
    opt_flags.update_opt_flags_constraints(constraints)

-     is_mixed_input = act_dtype_str != weight_dtype_str
-     if weight_dtype_str.startswith("mx"):
+     weight_mxfp = weight_dtype_str.startswith("mx")
+     if weight_mxfp:
        weight_dtype_str = weight_dtype_str[2:]
+     act_mxfp8 = act_dtype_str.startswith("mx")
+     act_is_float8 = act_dtype_str.startswith("float8")
+     if act_mxfp8:
+         act_dtype_str = act_dtype_str[2:]
+         dequantize_mxfp8_spec = FnSpecs(
+             FnName.DEQUANTIZE_MXFP8.name, dequantize_mxfp8_fn, (), ()
+         )

    test_bwd = False
    weight_dtype = dtype_str_to_torch(weight_dtype_str)
    act_dtype = dtype_str_to_torch(act_dtype_str)
-     act_is_float8 = act_dtype.itemsize == 1
-     precision_opt = init_precision(act_dtype, weight_dtype, is_mixed_input, n_expts_tot // n_expt_shards, device=device)
+     precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp, n_expts_tot // n_expt_shards, device=device)
    # precision_opt.x_pad_trans_requires_flexpoint = False
    if mode == "ragged":
        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
                                                   device=device)
    else:
        rdata = gindx = sindx = None
    x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act,
-                                                                  n_expt_shards, mode, act_dtype,  #
-                                                                  torch.bfloat16 if is_mixed_input else weight_dtype,
+                                                                  n_expt_shards, mode, torch.bfloat16 if act_mxfp8 else act_dtype,  #
+                                                                  torch.bfloat16 if weight_mxfp else weight_dtype,
                                                                 has_y_gammas, requires_grad=test_bwd, device=device)
    x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

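# Sketch of the dtype-string handling added above (values taken from one of the new
# mx x mx cases; plain string ops only, no triton_kernels imports needed):
act_dtype_str, weight_dtype_str = "mxfloat8_e4m3fn", "mxfloat4_e2m1"
weight_mxfp = weight_dtype_str.startswith("mx")      # True: weight is a micro-scaled (mx) format
act_mxfp8 = act_dtype_str.startswith("mx")           # True: activations are quantized to mxfp8
act_is_float8 = act_dtype_str.startswith("float8")   # False: checked before the "mx" prefix is stripped,
                                                     # so mxfp8 activations skip the flexpoint float8 path
if weight_mxfp:
    weight_dtype_str = weight_dtype_str[2:]          # "float4_e2m1": element type fed to dtype_str_to_torch
if act_mxfp8:
    act_dtype_str = act_dtype_str[2:]                # "float8_e4m3fn"
print(weight_mxfp, act_mxfp8, act_is_float8, weight_dtype_str, act_dtype_str)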
@@ -327,7 +354,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
    w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
    w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)

-     if is_mixed_input:
+     if weight_mxfp:
        mx_axis = w_tri.ndim - 2
        # compute layouts
        w_layout, w_layout_opts = layout.StridedLayout, dict()
@@ -346,6 +373,25 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
        precision_opt.weight_scale = w_scale_tri
+     epilogue = None
+     if act_mxfp8:
+         x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1)
+         x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1)
+         is_input_batched = x_tri.ndim == 3
+         y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
+         n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
+         y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
+         if sindx is None or mode == "batched":
+             if not is_input_batched:
+                 y_shape = (y_shape[1], y_shape[2])
+         else:
+             y_shape = (n_rows // rdata.n_expts_act, y_shape[-1])
+         y_scale_shape = y_shape[:-1] + (triton.cdiv(y_shape[-1], MXFP_BLOCK_SIZE),)
+         y_scale = torch.empty(y_scale_shape, dtype=torch.uint8, device=x_tri.device)
+         precision_opt = replace(precision_opt, act_scale=x_mx_scales_tri, out_scale=y_scale)
+         epilogue = Epilogue(dequantize_mxfp8_spec, tuple(), tuple(), effective_itemsize=6.0)
+     else:
+         y_scale = None

    if test_launch_metadata:

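# Worked example of the y_scale shape computed above: one uint8 scale is stored per
# MXFP_BLOCK_SIZE output elements along the last axis (assumed here to be the usual
# 32-element MX block; math.ceil stands in for triton.cdiv to keep the sketch self-contained).
import math
MXFP_BLOCK_SIZE = 32                                        # assumption for illustration
y_shape = (1000, 704)                                       # (rows, n) as in one of the new cases
y_scale_shape = y_shape[:-1] + (math.ceil(y_shape[-1] / MXFP_BLOCK_SIZE),)
print(y_scale_shape)                                        # (1000, 22): 704 / 32 = 22 blocks per row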
@@ -393,7 +439,7 @@ def _hook(launch_metadata):

    # triton
    try:
-         tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref)
+         tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue)
    except (opt_flags.InapplicableConstraint, NotImplementedError):
        pytest.skip("inapplicable opt_flags constraint")
    # If split_k > 1, then the intermediate tensor is fp32.
@@ -432,7 +478,16 @@ def round_x(x, idx):
    assert n_rows > 0
    ref_y = ref_y[:n_rows]
    tri_y = tri_y[:n_rows]
-     assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y)
+     if act_mxfp8:
+         tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, dtype=torch.bfloat16, axis=-1).to(ref_y.dtype)
+         ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1)
+         ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1)
+         maxtol = 4e-1
+         rmstol = 4e-2
+     else:
+         maxtol = None
+         rmstol = None
+     assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y, maxtol=maxtol, rmstol=rmstol)

    if act_is_float8:
        tri_y_scale = flex.out_data.actual_scale.clone()
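# Sketch of the comparison strategy above: the kernel's mxfp8 output is upcast back to
# bfloat16, and the bfloat16 reference is pushed through the same quantize/dequantize
# round trip so both sides carry the same rounding error. Rough intuition for the looser
# tolerances (an estimate, not a bound derived from the kernel): e4m3 keeps 3 mantissa
# bits, so the spacing between representable values near x is about x * 2**-3.
import torch
x = torch.tensor([1.0, 1.9, 3.7, 100.0])
spacing = 2.0 ** (torch.floor(torch.log2(x)) - 3)   # e4m3 quantization step near x
print((spacing / x).max())                          # ~0.125 worst case, inside maxtol=4e-1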
@@ -495,7 +550,7 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
    else:
        rdata = gindx = sindx = None

-     precision_opt = init_precision(act_dtype, weight_dtype, False, n_expts_tot // n_expt_shards, device=device)
+     precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, n_expts_tot // n_expt_shards, device=device)
    x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
                                         act_dtype, weight_dtype, False, requires_grad=False, device=device)
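# Quick check of the act_use_flexpoint predicate passed above: torch dtypes stringify
# with a "torch." prefix, so the startswith test selects every float8 variant and
# nothing else.
import torch
for dt in (torch.float8_e5m2, torch.float8_e4m3fn, torch.float16, torch.bfloat16):
    print(dt, str(dt).startswith("torch.float8"))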