 import pytest
 import torch
 from typing import Union
-# benchmarking utilities
 # routing utilities
 from triton_kernels.routing import routing
 # matmul utilities
@@ -12,7 +11,7 @@
 from triton_kernels.matmul_ogs import matmul_ogs, matmul_ogs_torch
 # numerics utilities
 from triton_kernels.numerics import InFlexData, OutFlexData
-from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
+from triton_kernels.numerics_details.mxfp import SwizzlingType, downcast_to_mxfp, upcast_from_mxfp
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
@@ -139,7 +138,7 @@ class Case:
     n_expts_act: int = 1
     n_expt_shards: int = 1
     split_k: int = 1
-    swizzle_mx_scale: bool = False
+    hbm_swizzling: bool = False
     epilogue_subtile: Union[bool, None] = None


@@ -174,25 +173,28 @@ class Case:
             Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),
             # mx types:
             Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
+            Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
             Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
+            Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
             Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),
+            Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
             Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4),
-            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, swizzle_mx_scale=True),
+            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
             Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),
             Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
-            Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, swizzle_mx_scale=True),
-            Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, swizzle_mx_scale=True),
-            Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, swizzle_mx_scale=False),
-            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, swizzle_mx_scale=False),
-            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, swizzle_mx_scale=True),
-            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, swizzle_mx_scale=False),
-            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, swizzle_mx_scale=True),
-            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, swizzle_mx_scale=False),
-            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, swizzle_mx_scale=True),
-            Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, swizzle_mx_scale=False),
-            Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, swizzle_mx_scale=True),
-            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, swizzle_mx_scale=False),
-            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, swizzle_mx_scale=True),
+            Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
+            Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
+            Case(1000, 704, 800, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),
+            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
+            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
+            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
+            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4),
+            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
+            Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),
+            Case(300, 400, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
+            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4),
+            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
             # AMD
             Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
             Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
@@ -214,8 +216,8 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, swizzle_mx_scale,
-            epilogue_subtile, device):
+            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            device):
     # TODO: remove when Triton FP8 supports proper RTNE
     if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
         pytest.skip("Float8 not tested on A100")
@@ -229,11 +231,22 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         pytest.skip("float8 x mx not supported with cuda capability < 10")
     if fused_scatter and split_k > 1:
         pytest.skip("fused scatter scratchpad not supported with split_k")
+    if hbm_swizzling:
+        if is_hip():
+            pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
+        if torch.cuda.get_device_capability()[0] < 9:
+            pytest.skip("NYI. Ampere swizzling.")
+        if torch.cuda.get_device_capability()[0] < 10:
+            if "mxfloat4" not in weight_dtype_str:
+                pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
+            if k % 64 != 0 or n % 64 != 0:
+                # Automatic padding not implemented for Hopper swizzle
+                pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

     torch.manual_seed(0)

     block_k = None
-    if is_persistent and weight_dtype_str.startswith("mx") and not torch.cuda.get_device_capability()[0] >= 10:
+    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
         # Override block_k for testing correctness. The default is temporarily 128 for
         # performance reasons which doesn't work with persistent matmul.
         # TODO: revisit when Triton is better for H100 + MXFP4
@@ -273,12 +286,27 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

     if is_mixed_input:
-        swizzle_axis = 2 if swizzle_mx_scale else None
+        if hbm_swizzling:
+            swizzle_axis = 2
+            if torch.cuda.get_device_capability()[0] < 10:
+                swizzle_value = SwizzlingType.HOPPER
+                swizzle_scale = SwizzlingType.HOPPER
+            else:
+                swizzle_value = None
+                swizzle_scale = SwizzlingType.BLACKWELL
+        else:
+            swizzle_axis = None
+            swizzle_value = None
+            swizzle_scale = None
         w_tri, mx_scales_tri, weight_scale_shape = downcast_to_mxfp(w_tri, weight_dtype, axis=1,
-                                                                     swizzle_axis=swizzle_axis)
-        w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=1, swizzle_axis=swizzle_axis)
-
-        precision_opt.mx_ctx = MicroscalingCtx(weight_scale=mx_scales_tri, swizzle_mx=swizzle_mx_scale,
+                                                                     swizzle_axis=swizzle_axis,
+                                                                     swizzle_value=swizzle_value,
+                                                                     swizzle_scale=swizzle_scale)
+        w_ref = upcast_from_mxfp(w_tri, mx_scales_tri, torch.bfloat16, axis=1, swizzle_axis=swizzle_axis,
+                                 swizzle_value=swizzle_value, swizzle_scale=swizzle_scale)
+
+        precision_opt.mx_ctx = MicroscalingCtx(weight_scale=mx_scales_tri, swizzle_value=swizzle_value,
+                                               swizzle_scale=swizzle_scale,
                                                actual_weight_scale_shape=weight_scale_shape)

     if is_persistent and not can_use_persistent_tma(x_tri, w_tri, gindx, precision_opt):
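
For readers skimming the diff: the test now derives the swizzle settings from a single `hbm_swizzling` flag, using a Hopper swizzle for both values and scales on pre-Blackwell GPUs and a Blackwell scale-only swizzle otherwise. Below is a minimal sketch of that selection logic, not part of the commit; the helper name `select_swizzling` is hypothetical, while `SwizzlingType` and the capability cutoffs come from the test code above.

```python
# Hypothetical helper; it only restates the branch inside test_op above.
import torch

from triton_kernels.numerics_details.mxfp import SwizzlingType


def select_swizzling(hbm_swizzling: bool):
    """Return (swizzle_axis, swizzle_value, swizzle_scale) for downcast_to_mxfp/upcast_from_mxfp."""
    if not hbm_swizzling:
        # No HBM swizzling: leave values and mx scales in their linear layout.
        return None, None, None
    if torch.cuda.get_device_capability()[0] < 10:
        # Pre-Blackwell (e.g. Hopper, sm90): swizzle both values and mx scales.
        return 2, SwizzlingType.HOPPER, SwizzlingType.HOPPER
    # Blackwell and newer: values stay linear, only the mx scales are swizzled.
    return 2, None, SwizzlingType.BLACKWELL
```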