-from dataclasses import dataclass, fields
+# isort: off
+# fmt: off
+from dataclasses import dataclass, fields, replace
 import pytest
 import torch
 from typing import Union
 from triton_kernels.routing import routing
 # matmul utilities
 import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
-from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs
+from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue
 from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch
 from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig
 from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
 from triton_kernels.tensor_details import layout
 # numerics utilities
 from triton_kernels.numerics import InFlexData, OutFlexData
-from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
+from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp, dequantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
@@ -78,9 +80,8 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
 # ---------------


-def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, device="cuda"):
-    act_use_flexpoint = out_dtype.itemsize == 1
-    weight_use_flexpoint = weight_dtype.itemsize == 1 and not is_mixed_input
+def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_expts_tot=1, device="cuda"):
+    weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp
     # flexpoint
     make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) +
                                                   ([val0]
@@ -106,13 +107,14 @@ def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config):

     def apply(x, scale):
         if scale is None:
-            return x.clone().detach().requires_grad_(True)
+            x = x.clone()
         elif scale.numel() == 1:
-            return (x.float() * scale).detach().requires_grad_(True)
+            x = x.float() * scale
         else:
             assert x.ndim == 3
             assert scale.numel() == x.shape[0]
-            return (x.float() * scale[:, None, None]).detach().requires_grad_(True)
+            x = x.float() * scale[:, None, None]
+        return x.detach().requires_grad_()

     return (
         apply(x_tri, flex_ctx.lhs_data.scale),
@@ -215,6 +217,19 @@ class Case:
             Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
             Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
             Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False),
+            Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
+            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
+            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
+            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
+            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
+            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
+            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
+            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
+            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4),
+            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
+            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4),
+            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
             # AMD
             Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
             Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
@@ -247,8 +262,12 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
             pytest.skip("Float8 not tested on A100")
         if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
             pytest.skip("float16 x mx not supported with cuda capability >= 10")
-        if "float8" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 10:
-            pytest.skip("float8 x mx not supported with cuda capability < 10")
+        if weight_dtype_str.startswith("mx"):
+            if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:
+                pytest.skip("float8 x mx not supported with cuda capability < 10")
+            if act_dtype_str == "mxfloat8_e4m3fn":
+                if is_persistent:
+                    pytest.skip("mx x mx not supported with persistent kernel")
         if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
             pytest.skip("Not enough memory on A100")

@@ -257,6 +276,8 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
             pytest.skip("float8 x mx only supported on CDNA4")
         if "float8" in act_dtype_str and "mxfloat8" in weight_dtype_str:
             pytest.skip("NYI: float8 x mxfloat8 not tested on AMD GPU")
+        if act_dtype_str.startswith("mx") and weight_dtype_str.startswith("mx"):
+            pytest.skip("NYI: mx x mx not tested on AMD GPU")
         if is_persistent:
             pytest.skip("NYI: Persistent kernel not supported on AMD GPU")
         if split_k > 1:
@@ -301,24 +322,30 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     }
     opt_flags.update_opt_flags_constraints(constraints)

-    is_mixed_input = act_dtype_str != weight_dtype_str
-    if weight_dtype_str.startswith("mx"):
+    weight_mxfp = weight_dtype_str.startswith("mx")
+    if weight_mxfp:
         weight_dtype_str = weight_dtype_str[2:]
+    act_mxfp8 = act_dtype_str.startswith("mx")
+    act_is_float8 = act_dtype_str.startswith("float8")
+    if act_mxfp8:
+        act_dtype_str = act_dtype_str[2:]
+    dequantize_mxfp8_spec = FnSpecs(
+        FnName.DEQUANTIZE_MXFP8.name, dequantize_mxfp8_fn, (), ()
+    )

     test_bwd = False
     weight_dtype = dtype_str_to_torch(weight_dtype_str)
     act_dtype = dtype_str_to_torch(act_dtype_str)
-    act_is_float8 = act_dtype.itemsize == 1
-    precision_opt = init_precision(act_dtype, weight_dtype, is_mixed_input, n_expts_tot // n_expt_shards, device=device)
+    precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp, n_expts_tot // n_expt_shards, device=device)
     # precision_opt.x_pad_trans_requires_flexpoint = False
     if mode == "ragged":
         m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
                                                    device=device)
     else:
         rdata = gindx = sindx = None
     x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act,
-                                                                 n_expt_shards, mode, act_dtype,  #
-                                                                 torch.bfloat16 if is_mixed_input else weight_dtype,
+                                                                 n_expt_shards, mode, torch.bfloat16 if act_mxfp8 else act_dtype,  #
+                                                                 torch.bfloat16 if weight_mxfp else weight_dtype,
                                                                  has_y_gammas, requires_grad=test_bwd, device=device)
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

@@ -327,7 +354,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
         w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)

-    if is_mixed_input:
+    if weight_mxfp:
         mx_axis = w_tri.ndim - 2
         # compute layouts
         w_layout, w_layout_opts = layout.StridedLayout, dict()
@@ -346,6 +373,25 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
         w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
         precision_opt.weight_scale = w_scale_tri
+    epilogue = None
+    if act_mxfp8:
+        x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1)
+        x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1)
+        is_input_batched = x_tri.ndim == 3
+        y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
+        n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
+        y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
+        if sindx is None or mode == "batched":
+            if not is_input_batched:
+                y_shape = (y_shape[1], y_shape[2])
+        else:
+            y_shape = (n_rows // rdata.n_expts_act, y_shape[-1])
+        y_scale_shape = y_shape[:-1] + (triton.cdiv(y_shape[-1], MXFP_BLOCK_SIZE),)
+        y_scale = torch.empty(y_scale_shape, dtype=torch.uint8, device=x_tri.device)
+        precision_opt = replace(precision_opt, act_scale=x_mx_scales_tri, out_scale=y_scale)
+        epilogue = Epilogue(dequantize_mxfp8_spec, tuple(), tuple(), effective_itemsize=6.0)
+    else:
+        y_scale = None

     if test_launch_metadata:

@@ -393,7 +439,7 @@ def _hook(launch_metadata):

     # triton
     try:
-        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref)
+        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue)
     except (opt_flags.InapplicableConstraint, NotImplementedError):
         pytest.skip("inapplicable opt_flags constraint")
     # If split_k > 1, then the intermediate tensor is fp32.
@@ -432,7 +478,16 @@ def round_x(x, idx):
         assert n_rows > 0
         ref_y = ref_y[:n_rows]
         tri_y = tri_y[:n_rows]
-    assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y)
+    if act_mxfp8:
+        tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, dtype=torch.bfloat16, axis=-1).to(ref_y.dtype)
+        ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1)
+        ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1)
+        maxtol = 4e-1
+        rmstol = 4e-2
+    else:
+        maxtol = None
+        rmstol = None
+    assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y, maxtol=maxtol, rmstol=rmstol)

     if act_is_float8:
         tri_y_scale = flex.out_data.actual_scale.clone()
@@ -495,7 +550,7 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
     else:
         rdata = gindx = sindx = None

-    precision_opt = init_precision(act_dtype, weight_dtype, False, n_expts_tot // n_expt_shards, device=device)
+    precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, n_expts_tot // n_expt_shards, device=device)
     x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
                                          act_dtype, weight_dtype, False, requires_grad=False, device=device)

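Note, not part of the commit: a minimal sketch of the mxfp8 round trip the new activation path exercises, using only helpers this diff imports (downcast_to_mxfp, upcast_from_mxfp and their pure-torch counterparts). The input shape and the CUDA device are illustrative assumptions; the call signatures mirror the ones used in the hunks above.

import torch
import triton
from triton_kernels.numerics_details.mxfp import (
    MXFP_BLOCK_SIZE, downcast_to_mxfp, upcast_from_mxfp,
    downcast_to_mxfp_torch, upcast_from_mxfp_torch,
)

# Illustrative activation tensor; the last axis is the quantization axis.
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

# Device-side quantization to mxfp8: values become float8_e4m3fn, with one uint8
# scale per MXFP_BLOCK_SIZE-element block along the last axis (the same layout the
# test assumes when it allocates y_scale with triton.cdiv(y_shape[-1], MXFP_BLOCK_SIZE)).
x_q, x_scale = downcast_to_mxfp(x, torch.float8_e4m3fn, axis=-1)
x_dq = upcast_from_mxfp(x_q, x_scale, torch.bfloat16, axis=-1)

# Pure-torch reference round trip, mirroring how ref_y is re-quantized before the
# loose maxtol/rmstol comparison in the updated assert path.
x_q_ref, x_scale_ref = downcast_to_mxfp_torch(x, torch.float8_e4m3fn, axis=-1)
x_dq_ref = upcast_from_mxfp_torch(x_q_ref, x_scale_ref, target_dtype=torch.bfloat16, axis=-1)

print("scale blocks per row:", triton.cdiv(x.shape[-1], MXFP_BLOCK_SIZE))
print("max quantization error:", (x - x_dq).abs().max().item())
print("triton vs torch dequant gap:", (x_dq - x_dq_ref).abs().max().item())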