from triton_kernels.routing import routing
# matmul utilities
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
- from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, MicroscalingCtx, FusedActivation, FnSpecs
- from triton_kernels.matmul_ogs import can_use_persistent_tma
+ from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs
from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch
from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig
+ from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
+ from triton_kernels.tensor_details import layout
# numerics utilities
from triton_kernels.numerics import InFlexData, OutFlexData
- from triton_kernels.numerics_details.mxfp import SwizzlingType, downcast_to_mxfp, upcast_from_mxfp
+ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
# testing utilities
from triton_kernels.testing import assert_close, compute_actual_scale
# target-specific utilities
@@ -53,20 +54,22 @@ def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_
def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype,
                      has_y_gammas, requires_grad=True, device="cuda"):
    torch.manual_seed(0)
-     assert mode in {'batched', 'ragged'}
+     assert mode in {'batched', "plain", 'ragged'}
    in_m = m * (n_expts_act if gindx is None else 1)
    shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k)
+     shape_batch = tuple() if mode == "plain" else (n_expts_tot // n_expt_shards, )
    x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad)
-     w = alloc_rand((n_expts_tot // n_expt_shards, k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
-     bias = alloc_rand((n_expts_tot // n_expt_shards, n), device=device, dtype=torch.float32,
-                       requires_grad=requires_grad)
+     w = alloc_rand(shape_batch + (k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
+     bias = alloc_rand(shape_batch + (n, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
    gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
    gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
    gs0 = gs0.detach().requires_grad_(requires_grad)
    gs1 = gs1.detach().requires_grad_(requires_grad)
    if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
        gs0 = None
        gs1 = None
+     if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
+         w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
    return x, w, bias, gs0, gs1

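The two changes above are shape-only: "plain" mode drops the leading expert dimension from w and bias, and the float8 branch keeps w's logical (k, n) shape while storing its last two dimensions transposed (column-major) in memory. A minimal, self-contained sketch of that logic, with illustrative values in place of the test's Case parameters:

import torch

n_expts_tot, n_expt_shards, k, n = 8, 1, 16, 32  # illustrative values only

for mode in ("plain", "ragged"):
    # Same expression as in init_compute_data above.
    shape_batch = tuple() if mode == "plain" else (n_expts_tot // n_expt_shards, )
    print(mode, shape_batch + (k, n), shape_batch + (n, ))
# plain  -> (16, 32)     (32,)
# ragged -> (8, 16, 32)  (8, 32)

# The float8 branch: same logical shape, transposed storage order.
w = torch.zeros(k, n)
w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
assert w.shape == (k, n) and w.stride() == (1, k)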
@@ -75,7 +78,7 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
# ---------------


- def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, mx_ctx=MicroscalingCtx(), device="cuda"):
+ def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, device="cuda"):
    act_use_flexpoint = out_dtype.itemsize == 1
    weight_use_flexpoint = weight_dtype.itemsize == 1 and not is_mixed_input
    # flexpoint
@@ -95,7 +98,7 @@ def init_precision(out_dtype, weight_dtype, is_mixed_input, n_expts_tot=1, mx_ct
        out_data=out_flex_data(4.00, act_use_flexpoint),
    )
    return PrecisionConfig(flex_ctx=flex_ctx, acc_scale=2.0 if act_use_flexpoint or weight_use_flexpoint else 1.0,
-                            mx_ctx=mx_ctx, out_dtype=out_dtype)
+                            out_dtype=out_dtype)


def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config):
@@ -183,8 +186,10 @@ class Case:
    Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2),
    Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),
    # mx types:
-     Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
-     Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
+     Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),
+     Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
+     Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1),
+     Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
    Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
    Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
    Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),
@@ -198,10 +203,10 @@ class Case:
    Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
    Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
    Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),
-     Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
-     Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
-     Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
-     Case(1000, 704, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
+     Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
+     Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+     Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
+     Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
    Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4),
    Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
    Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),
@@ -317,38 +322,32 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
                                                         has_y_gammas, requires_grad=test_bwd, device=device)
    x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

-     if is_mixed_input:
-         if hbm_swizzling:
-             swizzle_axis = 2
-             if torch.cuda.get_device_capability()[0] < 10:
-                 swizzle_value = SwizzlingType.HOPPER
-                 swizzle_scale = SwizzlingType.HOPPER
-             else:
-                 swizzle_value = None
-                 swizzle_scale = SwizzlingType.BLACKWELL
-         else:
-             swizzle_axis = None
-             swizzle_value = None
-             swizzle_scale = None
-         w_tri, mx_scales_tri, weight_scale_shape = downcast_to_mxfp(w_tri, weight_dtype, axis=1,
-                                                                      swizzle_axis=swizzle_axis,
-                                                                      swizzle_value=swizzle_value,
-                                                                      swizzle_scale=swizzle_scale)
-         w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=1, swizzle_axis=swizzle_axis,
-                                  swizzle_value=swizzle_value, swizzle_scale=swizzle_scale)
-
-         precision_opt.mx_ctx = MicroscalingCtx(weight_scale=mx_scales_tri, swizzle_value=swizzle_value,
-                                                swizzle_scale=swizzle_scale,
-                                                actual_weight_scale_shape=weight_scale_shape)
-
-     if is_persistent and not can_use_persistent_tma(x_tri, w_tri, gindx, precision_opt):
-         pytest.skip("persistent TMAs not supported for this test")
-
    if w_tri.shape[0] == 1:
        # Test the case when weight has dim 2, i.e., shape (K, N).
        w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
        w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)

+     if is_mixed_input:
+         capability_major = torch.cuda.get_device_capability()[0]
+         w_layout = layout.StridedLayout
+         w_scale_layout = layout.StridedLayout
+         if hbm_swizzling and "float4" in weight_dtype_str:
+             # weight layout
+             w_layouts = {9: layout.HopperMXValueLayout}
+             w_layout = w_layouts.get(capability_major, layout.StridedLayout)
+             # weight scale layout
+             w_scales_layouts = {9: layout.HopperMXScaleLayout, 10: layout.BlackwellMXScaleLayout}
+             w_scale_layout = w_scales_layouts.get(capability_major, layout.StridedLayout)
+         w_tri, mx_scales_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=-2)
+         w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=-2)
+         w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+         w_tri = convert_layout(wrap_torch_tensor(w_tri, w_tri_dtype), w_layout)
+         mx_scales_tri = convert_layout(wrap_torch_tensor(mx_scales_tri), w_scale_layout)
+         precision_opt.weight_scale = mx_scales_tri
+
+     # if not is_persistent and precision_opt.weight_scale is not None:
+     #     pytest.skip("non-persistent not supported with mxfp")
+
    if test_launch_metadata:

        def _clobber(t, used_mask):
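For readers new to the layout API used in the block above, here is a condensed sketch of the mixed-input weight path on a Hopper-class GPU (compute capability 9.x) with an FP4 weight type. Every call appears in the hunk above; w and weight_dtype are placeholders, and the test additionally stores the returned scales on precision_opt.weight_scale:

import torch
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
from triton_kernels.tensor_details import layout

def quantize_mx4_weight_hopper(w, weight_dtype):
    # Quantize along K (axis=-2) into MX values plus per-block scales.
    w_q, w_scales = downcast_to_mxfp(w, weight_dtype, axis=-2)
    # Dequantized bfloat16 copy, used as the reference operand in the test.
    w_ref = upcast_from_mxfp(w_q, w_scales, torch.bfloat16, axis=-2)
    # Convert values and scales to the swizzled layouts the test selects for
    # capability 9 when hbm_swizzling is set; StridedLayout is the fallback.
    w_q = convert_layout(wrap_torch_tensor(w_q, FP4), layout.HopperMXValueLayout)
    w_scales = convert_layout(wrap_torch_tensor(w_scales), layout.HopperMXScaleLayout)
    return w_q, w_scales, w_ref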
@@ -394,7 +393,10 @@ def _hook(launch_metadata):
    flex = precision_opt.flex_ctx

    # triton
-     tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref)
+     try:
+         tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref)
+     except (opt_flags.InapplicableConstraint, NotImplementedError):
+         pytest.skip("inapplicable opt_flags constraint")
    # If split_k > 1, then the intermediate tensor is fp32.
    sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1
    sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
@@ -498,16 +500,16 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
    x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
                                         act_dtype, weight_dtype, False, requires_grad=False, device=device)

-     if is_persistent and not can_use_persistent_tma(x.view(1, x.shape[-2], x.shape[-1]),
-                                                     w.view(1, w.shape[-2], w.shape[-1]), gindx, precision_opt):
-         pytest.skip("persistent TMAs not supported for this test")
-
    if mode == "batched":
        rdata, gindx, sindx = None, None, None
-     a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha,
-                precision_config=SwiGLUPrecisionConfig(swiglu_limit))
-     b = matmul_ogs(
-         x, w, bias, rdata, gindx, sindx, precision_opt,
-         fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (swiglu_alpha, swiglu_limit),
-                                          2))
+
+     try:
+         a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha,
+                    precision_config=SwiGLUPrecisionConfig(swiglu_limit))
+         b = matmul_ogs(
+             x, w, bias, rdata, gindx, sindx, precision_opt,
+             fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
+                                              (swiglu_alpha, swiglu_limit), 2))
+     except opt_flags.InapplicableConstraint:
+         pytest.skip("inapplicable constraint")
    assert_close(a, b)