@@ -193,6 +193,7 @@ class Case:
         Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
         Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),
         Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
+        Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
         Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
         Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
         Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
@@ -243,6 +244,9 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
             pytest.skip("float16 x mx not supported with cuda capability >= 10")
         if "float8" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 10:
             pytest.skip("float8 x mx not supported with cuda capability < 10")
+        if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
+            pytest.skip("Not enough memory on A100")
+
     elif is_hip():
         if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4():
             pytest.skip("float8 x mx only supported on CDNA4")