@@ -266,6 +266,10 @@ def filter_layouts(layouts):
266266 return [l for l in layouts if is_layout_applicable (l )]
267267
268268
def filter_layout_pairs(pairs):
    """Return the entries of *pairs* whose first two layouts are both applicable.

    Each entry is indexed at positions 0 and 1, and both elements must
    satisfy ``is_layout_applicable`` for the entry to be kept. The second
    element is only checked when the first one passes. The relative order
    of the surviving entries is the same as in *pairs*.
    """
    applicable = []
    for pair in pairs:
        # Skip early so the second layout is never queried when the first fails.
        if not is_layout_applicable(pair[0]):
            continue
        if is_layout_applicable(pair[1]):
            applicable.append(pair)
    return applicable
271+
272+
269273@pytest .mark .interpreter
270274@pytest .mark .parametrize ("dtype_x" , list (dtypes ) + ["bfloat16" ])
271275def test_empty_kernel (dtype_x , device ):
@@ -5770,12 +5774,18 @@ def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path:
57705774 MmaLayout ((3 , 0 ), [4 , 1 ], [1 , 1 ], [1 , 1 ], [0 , 1 ], [16 , 64 , 16 ]),
57715775 MmaLayout ((3 , 0 ), [4 , 1 ], [1 , 1 ], [1 , 1 ], [0 , 1 ], [16 , 128 , 16 ]),
57725776 ],
5777+ [
5778+ DpasLayout (repeatCount = 8 , systolic_depth = 8 , execution_size = 8 , ops_per_chan = 1 , threads_per_warp = 32 ,
5779+ warps_per_cta = [4 , 1 ], rep_cluster = [1 , 1 ]),
5780+ DpasLayout (repeatCount = 8 , systolic_depth = 8 , execution_size = 8 , ops_per_chan = 2 , threads_per_warp = 32 ,
5781+ warps_per_cta = [2 , 2 ], rep_cluster = [1 , 1 ]),
5782+ ],
57735783]
57745784
57755785
57765786@pytest .mark .parametrize ("M, N" , [[64 , 1 ], [1 , 64 ], [64 , 64 ], [128 , 128 ], [256 , 256 ]])
57775787@pytest .mark .parametrize ("dtype" , ['float16' ])
5778- @pytest .mark .parametrize ("mma_pair" , filter_layouts (mma_pairs ))
5788+ @pytest .mark .parametrize ("mma_pair" , filter_layout_pairs (mma_pairs ))
57795789def test_convert_mma2mma (M , N , mma_pair , dtype , device , tmp_path : pathlib .Path ):
57805790 src_layout , _ = mma_pair
57815791 num_warps = np .prod (src_layout .warps_per_cta )
0 commit comments