Skip to content

Commit 66795f8

Browse files
Improve test_convert_mma2mma coverage (#3115)
When the parameter set `['mma_pair']` is empty, test case is considered as skipped like below: ``` language/test_core.py::test_convert_mma2mma[mma_pair0-float16-64-1] SKIPPED (got empty parameter set ['mma_pair'], function test_...) language/test_core.py::test_convert_mma2mma[mma_pair0-float16-1-64] SKIPPED (got empty parameter set ['mma_pair'], function test_...) language/test_core.py::test_convert_mma2mma[mma_pair0-float16-64-64] SKIPPED (got empty parameter set ['mma_pair'], function test...) language/test_core.py::test_convert_mma2mma[mma_pair0-float16-128-128] SKIPPED (got empty parameter set ['mma_pair'], function te...) language/test_core.py::test_convert_mma2mma[mma_pair0-float16-256-256] SKIPPED (got empty parameter set ['mma_pair'], function te...) ``` Before: ``` language: passed: 11964, failed: 0, skipped: 7, xfailed: 547, total: 12518, fixme: 0, pass rate (w/o xfailed): 99.94% all: passed: 18664, failed: 0, skipped: 28, xfailed: 1309, total: 20001, fixme: 48, pass rate (w/o xfailed): 99.85% ``` After: ``` language: passed: 11969, failed: 0, skipped: 2, xfailed: 547, total: 12518, fixme: 0, pass rate (w/o xfailed): 99.98% all: passed: 18669, failed: 0, skipped: 23, xfailed: 1309, total: 20001, fixme: 48, pass rate (w/o xfailed): 99.88% ``` Signed-off-by: Whitney Tsang <[email protected]>
1 parent b9da9cc commit 66795f8

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

python/test/unit/language/test_core.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,10 @@ def filter_layouts(layouts):
266266
return [l for l in layouts if is_layout_applicable(l)]
267267

268268

269+
def filter_layout_pairs(pairs):
270+
return [p for p in pairs if is_layout_applicable(p[0]) and is_layout_applicable(p[1])]
271+
272+
269273
@pytest.mark.interpreter
270274
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
271275
def test_empty_kernel(dtype_x, device):
@@ -5770,12 +5774,18 @@ def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path:
57705774
MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 16]),
57715775
MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]),
57725776
],
5777+
[
5778+
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
5779+
warps_per_cta=[4, 1], rep_cluster=[1, 1]),
5780+
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=2, threads_per_warp=32,
5781+
warps_per_cta=[2, 2], rep_cluster=[1, 1]),
5782+
],
57735783
]
57745784

57755785

57765786
@pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
57775787
@pytest.mark.parametrize("dtype", ['float16'])
5778-
@pytest.mark.parametrize("mma_pair", filter_layouts(mma_pairs))
5788+
@pytest.mark.parametrize("mma_pair", filter_layout_pairs(mma_pairs))
57795789
def test_convert_mma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
57805790
src_layout, _ = mma_pair
57815791
num_warps = np.prod(src_layout.warps_per_cta)

0 commit comments

Comments
 (0)