@@ -3838,6 +3838,13 @@ def get_test_dot_vdot2_cases():
             (4, 32, 32, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None)]
 
 
+def get_test_small_dots_cases():
+    if not is_cuda():
+        return []
+    return [(2, 4, 32, 1, False, False, 'None', 'ieee', 'float16', 'float32', 1, None),
+            (1, 2, 32, 1, False, False, 'None', 'ieee', 'float8e5', 'float32', 1, None)]
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize(
     "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size",
@@ -3851,15 +3858,16 @@ def get_test_dot_vdot2_cases():
     get_test_dot_fp8_output_cases() + \
     get_test_dot_small_k_mfma_cases() + \
     get_test_dot_small_mn_fma_cases() + \
-    get_test_dot_softmax())
+    get_test_dot_softmax() + \
+    get_test_small_dots_cases())
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size,
              num_ctas, device):
     if is_interpreter():
         if in_dtype == 'bfloat16':
             pytest.skip("bfloat16 is not supported in the interpreter")
     else:
-        if not is_hip() and (M < 16 or N < 16 or K < 16):
+        if not is_hip() and K < 16:
             pytest.skip("small dots are supported only on HIP at the moment")
         if is_cuda():
             capability = torch.cuda.get_device_capability()
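
Aside (not part of the commit): a minimal, self-contained sketch of the kind of sub-16 M/N dot the new `get_test_small_dots_cases` entries exercise: M=2, N=4, K=32 in float16 with a float32 accumulator. The kernel below is a hypothetical illustration, not code from this diff, and assumes a CUDA device plus a Triton build that accepts dot shapes this small (the very feature being tested here).

```python
import torch
import triton
import triton.language as tl


@triton.jit
def small_dot_kernel(x_ptr, y_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # The whole (tiny) operands fit in one block; offsets assume row-major layout.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    x = tl.load(x_ptr + offs_m[:, None] * K + offs_k[None, :])  # (M, K) fp16
    y = tl.load(y_ptr + offs_k[:, None] * N + offs_n[None, :])  # (K, N) fp16
    z = tl.dot(x, y, out_dtype=tl.float32)                      # (M, N) fp32 accumulator
    tl.store(z_ptr + offs_m[:, None] * N + offs_n[None, :], z)


M, N, K = 2, 4, 32  # mirrors the first new test case
x = torch.randn((M, K), dtype=torch.float16, device='cuda')
y = torch.randn((K, N), dtype=torch.float16, device='cuda')
z = torch.empty((M, N), dtype=torch.float32, device='cuda')
small_dot_kernel[(1, )](x, y, z, M=M, N=N, K=K)
torch.testing.assert_close(z, x.float() @ y.float(), atol=1e-3, rtol=1e-3)
```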
@@ -4097,10 +4105,12 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
             assert 'wgmma.mma_async.sync.aligned' in ptx or \
                 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
         elif in_dtype == "float8e5" and out_dtype == tl.float32:
-            if capability[0] == 9:
+            if capability[0] == 9 and M >= 64 and N >= 8:
                 assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2' in ptx
+            elif capability[0] >= 8 and M < 64:
+                assert 'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32' in ptx
         elif in_dtype == "float8e4nv" and out_dtype == tl.float32:
-            if capability[0] == 9:
+            if capability[0] == 9 and M >= 64 and N >= 8:
                 assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx
         if is_tcgen5 and epilogue == 'softmax' and M >= 128:
             # check that there is no shared memory exchange in the softmax
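
Aside (not part of the commit): the `ptx` string checked in the last hunk comes from Triton's launch handle, which this test file reads as `pgm.asm['ptx']`. A hedged sketch of reproducing the new small-M assertion interactively, reusing `small_dot_kernel` and its arguments from the sketch above and assuming an sm_80+ device where the small dot lowers to `mma` rather than `wgmma`:

```python
# The launch returns a handle whose .asm dict holds the generated assembly.
pgm = small_dot_kernel[(1, )](x, y, z, M=M, N=N, K=K)
ptx = pgm.asm['ptx']
# For M < 64 the diff expects the fp16 m16n8k16 MMA form; checking the
# generic prefix here keeps the sketch robust across PTX variants.
assert 'mma.sync.aligned' in ptx
```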