@@ -3299,7 +3299,7 @@ def convert_fp8_to_fp32(x, device, dtype_str):
     [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1)
      for float8_type in ["float8e5", "float8e4nv"]] +
     [(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1)
-     for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4)]
+     for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4), (8, 16, 16, 1)]
      for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
      for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]])
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
@@ -3308,7 +3308,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
         if in_dtype == 'bfloat16':
             pytest.xfail("bfloat16 is not supported in the interpreter")
     else:
-        if not is_hip() and (M < 16 or N < 16 or K < 16):
+        if is_xpu():
+            if (M < 8 or N < 16 or (K < 16 and in_dtype == 'float16') or (K < 8 and in_dtype == 'float32')):
+                pytest.xfail("XPU: small dots are not supported")
+        elif not is_hip() and (M < 16 or N < 16 or K < 16):
             pytest.skip("small dots are supported only on HIP at the moment")
         if is_cuda():
             capability = torch.cuda.get_device_capability()
@@ -3760,7 +3763,7 @@ def make_finite(x, dtype):
     [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str)
      for B in [1, 2, 8]
      for num_warps in [1, 2, 4]
-     for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)]
+     for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8), (8, 16)]
      for M, N, K in [(32, 32, 32)]
      for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]])
 def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str, device):
@@ -3775,7 +3778,10 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
             pytest.skip(f"{out_dtype_str} has low precision in WMMA dot")
     else:
         input_precision = "tf32" if (is_cuda() or is_xpu()) and in_dtype_str == 'float32' else "ieee"
-        if not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16):
+        if is_xpu():
+            if (BLOCK_M < 8 or BLOCK_N < 16):
+                pytest.xfail("XPU: small dots are not supported")
+        elif not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16):
             pytest.skip("small dots are supported only on HIP at the moment")
 
     if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32":
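For reference, the XPU guards added in test_dot and test_dot3d reduce to one shape predicate: M >= 8, N >= 16, and K >= 16 for float16 (K >= 8 for float32). A minimal standalone sketch of that predicate follows; the helper name xpu_supports_dot_shape and the standalone framing are illustrative assumptions, not part of this patch, which expresses the check inline in the tests above.

# Illustrative sketch only: mirrors the XPU guards in the hunks above.
def xpu_supports_dot_shape(m: int, n: int, k: int, in_dtype: str) -> bool:
    """Return True if a dot with this tile shape is expected to work on XPU."""
    min_k = 16 if in_dtype == 'float16' else 8
    return m >= 8 and n >= 16 and k >= min_k

# The newly parametrized shape (8, 16, 16) passes for float16, while
# (2, 2, 16) remains an expected failure on XPU.
assert xpu_supports_dot_shape(8, 16, 16, 'float16')
assert not xpu_supports_dot_shape(2, 2, 16, 'float16')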