@@ -161,6 +161,13 @@ class Case:
161161 ", " .join (f .name for f in fields (Case )),
162162 [
163163 tuple (getattr (case , f .name ) for f in fields (Case )) for case in [
164+ # Zero-sized args:
165+ Case (0 , 5 , 7 , "ragged" , "float16" , "float16" ),
166+ Case (5 , 0 , 7 , "ragged" , "float16" , "float16" ),
167+ Case (5 , 7 , 0 , "ragged" , "float16" , "float16" ),
168+ Case (0 , 5 , 7 , "batched" , "float16" , "float16" ),
169+ Case (5 , 0 , 7 , "batched" , "float16" , "float16" ),
170+ Case (5 , 7 , 0 , "batched" , "float16" , "float16" ),
164171 # Non-mx types:
165172 Case (16 , 256 , 256 , "ragged" , "float16" , "float16" , 128 , 4 ),
166173 Case (16 , 256 , 256 , "ragged" , "float16" , "float16" , 128 , 4 , n_expt_shards = 2 ),
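For reference, the zero-sized cases added above follow standard empty-tensor matmul semantics; a minimal sketch of the expected reference behavior (plain torch, independent of the kernel under test):

    import torch

    # k == 0: the reduction is empty, so the output is a zero-filled (m, n) tensor.
    a, b = torch.randn(5, 0), torch.randn(0, 7)
    assert (a @ b).shape == (5, 7) and (a @ b).abs().max() == 0

    # m == 0 (or n == 0): the output is simply an empty tensor.
    a, b = torch.randn(0, 5), torch.randn(5, 7)
    assert (a @ b).shape == (0, 7)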
@@ -301,7 +308,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

     # launch metadata for batched / mx types may not work yet.
-    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str) and fused_scatter
+    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str) and fused_scatter and m * n * k != 0

     torch.manual_seed(0)
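The added `m * n * k != 0` term is a compact guard that all three problem dims are nonzero, presumably because a zero-sized problem launches no work and so produces no launch metadata to check. It is equivalent to the explicit form (hypothetical helper, not in the diff):

    def all_dims_nonzero(m: int, n: int, k: int) -> bool:
        return m != 0 and n != 0 and k != 0

    assert all_dims_nonzero(5, 7, 3) == (5 * 7 * 3 != 0)  # True
    assert all_dims_nonzero(5, 0, 3) == (5 * 0 * 3 != 0)  # False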
@@ -349,7 +356,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
                               has_y_gammas, requires_grad=test_bwd, device=device)
     x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

-    if w_tri.shape[0] == 1:
+    if w_tri.shape[0] == 1 and mode != "batched":
         # Test the case when weight has dim 2, i.e., shape (K, N).
         w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
         w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)
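The extra `mode != "batched"` check matters because in batched mode the leading dim of `w_tri` is a batch dim, so a size-1 batch must stay 3-D; only in the non-batched layouts is a (1, K, N) weight safely interchangeable with its squeezed (K, N) form. A small sketch of the distinction (shapes assumed for illustration):

    import torch

    w = torch.randn(1, 4, 8)  # leading dim: single expert, or a batch of one

    # Non-batched: a (1, K, N) weight can also be passed as its 2-D (K, N) view.
    w2d = w.squeeze(0)
    assert w2d.shape == (4, 8)

    # Batched: squeezing would silently turn a batch-of-one into a plain matrix,
    # changing how the op interprets the argument, so the 3-D shape is kept.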