17
17
# testing utilities
18
18
from triton_kernels .testing import assert_close , compute_actual_scale
19
19
# target-specific utilities
20
- from triton_kernels .target_info import is_hip
20
+ from triton_kernels .target_info import is_hip , is_hip_cdna3 , is_cuda , is_hip_cdna4
21
21
22
22
# ---------------
23
23
# initialize data
@@ -75,18 +75,19 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
75
75
# ---------------
76
76
77
77
78
- def init_precision (out_dtype , act_use_flexpoint , weight_use_flexpoint , n_expts_tot = 1 , mx_ctx = MicroscalingCtx (),
79
- device = "cuda" ):
78
+ def init_precision (out_dtype , weight_dtype , is_mixed_input , n_expts_tot = 1 , mx_ctx = MicroscalingCtx (), device = "cuda" ):
79
+ act_use_flexpoint = out_dtype .itemsize == 1
80
+ weight_use_flexpoint = weight_dtype .itemsize == 1 and not is_mixed_input
80
81
# flexpoint
81
82
make_tensor = lambda val0 , val1 : torch .tensor ([val0 , val1 ] * (n_expts_tot // 2 ) +
82
83
([val0 ]
83
84
if n_expts_tot % 2 else []), dtype = torch .float32 , device = device )
84
85
make_scalar = lambda val : torch .tensor ([val ], dtype = torch .float32 , device = device )
85
- in_flex_data = lambda scale , use_flex : InFlexData (dtype = torch . float8_e5m2 , scale = make_scalar (scale )
86
+ in_flex_data = lambda scale , use_flex : InFlexData (dtype = out_dtype , scale = make_scalar (scale )
86
87
) if use_flex else InFlexData ()
87
- in_flex_edata = lambda scale0 , scale1 , use_flex : InFlexData (dtype = torch . float8_e5m2 , scale = make_tensor (
88
- scale0 , scale1 ) ) if use_flex else InFlexData ()
89
- out_flex_data = lambda scale , use_flex : OutFlexData (dtype = torch . float8_e5m2 , expected_scale = make_scalar (
88
+ in_flex_edata = lambda scale0 , scale1 , use_flex : InFlexData (dtype = weight_dtype , scale = make_tensor (scale0 , scale1 )
89
+ ) if use_flex else InFlexData ()
90
+ out_flex_data = lambda scale , use_flex : OutFlexData (dtype = out_dtype , expected_scale = make_scalar (
90
91
scale ), actual_scale = make_scalar (0 ), checksum_scale = make_scalar (0 )) if use_flex else OutFlexData ()
91
92
flex_ctx = FlexCtx (
92
93
lhs_data = in_flex_data (1.25 , act_use_flexpoint ),
@@ -211,8 +212,11 @@ class Case:
211
212
Case (1000 , 400 , 400 , "ragged" , "float8_e4m3fnuz" , "float8_e4m3fnuz" , 3 , 1 ),
212
213
Case (600 , 400 , 400 , "ragged" , "float8_e4m3fnuz" , "float8_e4m3fnuz" , 4 , 2 ),
213
214
Case (600 , 400 , 400 , "ragged" , "float8_e4m3fnuz" , "float8_e4m3fnuz" , 4 , 2 , n_expt_shards = 2 ),
214
- Case (600 , 400 , 400 , "ragged" , "float8_e4m3fnuz" , "float8_e4m3fnuz" , 4 , 2 ),
215
215
Case (600 , 400 , 400 , "ragged" , "float8_e4m3fnuz" , "float8_e4m3fnuz" , 4 , 2 , split_k = 2 ),
216
+ Case (300 , 400 , 400 , "ragged" , "float8_e4m3fn" , "float8_e4m3fn" ),
217
+ Case (1000 , 400 , 400 , "ragged" , "float8_e4m3fn" , "float8_e4m3fn" , 3 , 1 ),
218
+ Case (600 , 400 , 400 , "ragged" , "float8_e4m3fn" , "float8_e4m3fn" , 4 , 2 ),
219
+ Case (600 , 400 , 400 , "ragged" , "float8_e4m3fn" , "float8_e4m3fn" , 4 , 2 , n_expt_shards = 2 ),
216
220
]
217
221
],
218
222
)
@@ -230,16 +234,26 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
230
234
n_expts_act , n_expt_shards , mode , act_dtype_str , weight_dtype_str , block_m , hbm_swizzling , epilogue_subtile ,
231
235
device , opt_flags_scope ):
232
236
# TODO: remove when Triton FP8 supports proper RTNE
233
- if "float8" in weight_dtype_str and torch .cuda .get_device_capability ()[0 ] < 9 :
234
- pytest .skip ("Float8 not tested on A100" )
235
- if "float8_e4m3fnuz" in weight_dtype_str and not is_hip ():
236
- pytest .skip ("float8_e4m3fnuz only tested on HIP platforms" )
237
- if "mx" in weight_dtype_str and is_hip ():
238
- pytest .skip ("mxfloat* only tested on CUDA platforms" )
239
- if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch .cuda .get_device_capability ()[0 ] >= 10 :
240
- pytest .skip ("float16 x mx not supported with cuda capability >= 10" )
241
- if "float8" in act_dtype_str and "mx" in weight_dtype_str and torch .cuda .get_device_capability ()[0 ] < 10 :
242
- pytest .skip ("float8 x mx not supported with cuda capability < 10" )
237
+ if is_cuda ():
238
+ if "float8" in weight_dtype_str and torch .cuda .get_device_capability ()[0 ] < 9 :
239
+ pytest .skip ("Float8 not tested on A100" )
240
+ if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch .cuda .get_device_capability ()[0 ] >= 10 :
241
+ pytest .skip ("float16 x mx not supported with cuda capability >= 10" )
242
+ if "float8" in act_dtype_str and "mx" in weight_dtype_str and torch .cuda .get_device_capability ()[0 ] < 10 :
243
+ pytest .skip ("float8 x mx not supported with cuda capability < 10" )
244
+ elif is_hip ():
245
+ if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4 ():
246
+ pytest .skip ("float8 x mx only supported on CDNA4" )
247
+ if "float8" in act_dtype_str and "mxfloat8" in weight_dtype_str :
248
+ pytest .skip ("NYI: float8 x mxfloat8 not tested on AMD GPU" )
249
+ if is_persistent :
250
+ pytest .skip ("NYI: Persistent kernel not supported on AMD GPU" )
251
+ if split_k > 1 :
252
+ pytest .skip ("splitK hasn't been fully tested on AMD GPU." )
253
+
254
+ if "float8_e4m3fnuz" in (weight_dtype_str , act_dtype_str ) and not is_hip_cdna3 ():
255
+ pytest .skip ("float8_e4m3fnuz only tested on AMD CDNA3 Platform" )
256
+
243
257
if fused_scatter and split_k > 1 :
244
258
pytest .skip ("fused scatter scratchpad not supported with split_k" )
245
259
if hbm_swizzling :
@@ -284,9 +298,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
284
298
weight_dtype = dtype_str_to_torch (weight_dtype_str )
285
299
act_dtype = dtype_str_to_torch (act_dtype_str )
286
300
act_is_float8 = act_dtype .itemsize == 1
287
- weight_is_float8 = weight_dtype .itemsize == 1
288
- precision_opt = init_precision (act_dtype , act_is_float8 , weight_is_float8 and not is_mixed_input ,
289
- n_expts_tot // n_expt_shards , device = device )
301
+ precision_opt = init_precision (act_dtype , weight_dtype , is_mixed_input , n_expts_tot // n_expt_shards , device = device )
290
302
# precision_opt.x_pad_trans_requires_flexpoint = False
291
303
if mode == "ragged" :
292
304
m , rdata , gindx , sindx = init_routing_data (m , n_expts_tot , n_expts_act , n_expt_shards , do_gather , do_scatter ,
@@ -456,7 +468,7 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
456
468
else :
457
469
rdata = gindx = sindx = None
458
470
459
- precision_opt = init_precision (act_dtype , False , False , n_expts_tot // n_expt_shards , device = device )
471
+ precision_opt = init_precision (act_dtype , weight_dtype , False , n_expts_tot // n_expt_shards , device = device )
460
472
x , w , bias , _ , _ = init_compute_data (m , n , k , gindx , sindx , n_expts_tot , n_expts_act , n_expt_shards , mode ,
461
473
act_dtype , weight_dtype , False , requires_grad = False , device = device )
462
474
0 commit comments