Skip to content

Commit 4c4709d

Browse files
Improve test_dot[3d] coverage (#3111)
This change improves coverage of the `test_dot` and `test_dot3d` tests. Before: ``` language: passed: 11936, failed: 0, skipped: 101, xfailed: 453, total: 12490, fixme: 0, pass rate (w/o xfailed): 99.16% all: passed: 18608, failed: 0, skipped: 122, xfailed: 1215, total: 19945, fixme: 48, pass rate (w/o xfailed): 99.35% ``` After: ``` language: passed: 11964, failed: 0, skipped: 7, xfailed: 547, total: 12518, fixme: 0, pass rate (w/o xfailed): 99.94% all: passed: 18664, failed: 0, skipped: 28, xfailed: 1309, total: 20001, fixme: 48, pass rate (w/o xfailed): 99.85% ``` Signed-off-by: Whitney Tsang <[email protected]>
1 parent b59bb9a commit 4c4709d

File tree

2 files changed: +13 additions, -4 deletions

python/test/unit/language/test_core.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3299,7 +3299,7 @@ def convert_fp8_to_fp32(x, device, dtype_str):
32993299
[(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1)
33003300
for float8_type in ["float8e5", "float8e4nv"]] +
33013301
[(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1)
3302-
for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4)]
3302+
for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4), (8, 16, 16, 1)]
33033303
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
33043304
for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]])
33053305
@pytest.mark.parametrize("num_ctas", num_ctas_list)
@@ -3308,7 +3308,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
33083308
if in_dtype == 'bfloat16':
33093309
pytest.xfail("bfloat16 is not supported in the interpreter")
33103310
else:
3311-
if not is_hip() and (M < 16 or N < 16 or K < 16):
3311+
if is_xpu():
3312+
if (M < 8 or N < 16 or (K < 16 and in_dtype == 'float16') or (K < 8 and in_dtype == 'float32')):
3313+
pytest.xfail("XPU: small dots are not supported")
3314+
elif not is_hip() and (M < 16 or N < 16 or K < 16):
33123315
pytest.skip("small dots are supported only on HIP at the moment")
33133316
if is_cuda():
33143317
capability = torch.cuda.get_device_capability()
@@ -3760,7 +3763,7 @@ def make_finite(x, dtype):
37603763
[(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str)
37613764
for B in [1, 2, 8]
37623765
for num_warps in [1, 2, 4]
3763-
for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)]
3766+
for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8), (8, 16)]
37643767
for M, N, K in [(32, 32, 32)]
37653768
for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]])
37663769
def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str, device):
@@ -3775,7 +3778,10 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
37753778
pytest.skip(f"{out_dtype_str} has low precision in WMMA dot")
37763779
else:
37773780
input_precision = "tf32" if (is_cuda() or is_xpu()) and in_dtype_str == 'float32' else "ieee"
3778-
if not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16):
3781+
if is_xpu():
3782+
if (BLOCK_M < 8 or BLOCK_N < 16):
3783+
pytest.xfail("XPU: small dots are not supported")
3784+
elif not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16):
37793785
pytest.skip("small dots are supported only on HIP at the moment")
37803786

37813787
if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32":

scripts/skiplist/lts/language.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,9 @@ test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float16]
274274
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float32]
275275
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float32-float32]
276276
test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-int8-int8]
277+
test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-8-16-float16-float16]
278+
test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-8-16-float16-float16]
279+
test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-8-16-float16-float16]
277280
test/unit/language/test_core.py::test_scaled_dot[32-32-64-True-False-False-e2m1-bf16-4-16-1]
278281
test/unit/language/test_core.py::test_scaled_dot[32-32-64-True-False-True-e4m3-e4m3-4-16-1]
279282
test/unit/language/test_core.py::test_scaled_dot[32-32-128-True-True-False-e5m2-bf16-4-16-1]

Comments (0)