
Commit 8a6dfa5

[CI] Tweak test config (#7095)
1 parent 861f963 commit 8a6dfa5

File tree

1 file changed: +2 additions, -4 deletions


python/test/regression/test_cast_matmul.py (2 additions, 4 deletions)

@@ -13,7 +13,7 @@
 import triton.language as tl
 from triton._internal_testing import is_hip_cdna3, is_cuda, is_hip

-input_dtypes = ["bfloat16", "float16", "float32", "float64"]
+input_dtypes = ["bfloat16", "float16", "float32"]
 if is_cuda():
     input_dtypes += ["int8", "float8_e5m2"]
     cc = torch.cuda.get_device_capability(0)
@@ -80,13 +80,11 @@ def matmul_kernel(A, B, C, M, N, K, #
 for BLOCK_K in [16, 32, 64] #
 for BLOCK_M in [16, 64] #
 for BLOCK_N in [16, 64, 128] #
-for (M, K, N) in [(128, 128, 128), (768, 768, 1024)] #
+for (M, K, N) in [(768, 768, 1024)] #
 for w in input_dtypes
 for x in input_dtypes #
 for o in out_dtypes])
 def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, BLOCK_N, w_dtype, x_dtype, out_dtype, device):
-    if (is_cuda() and torch.cuda.get_device_capability(0)[0] >= 10) and (BLOCK_K, BLOCK_M, BLOCK_N) == (64, 64, 128):
-        pytest.skip("skip as they run out of shared memory")
     if is_hip() and (BLOCK_K, BLOCK_M, BLOCK_N) in ((64, 64, 128), (64, 16, 128)):
         pytest.skip("skip as they run out of shared memory")
     if x_dtype == w_dtype:
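For a rough sense of why this counts as a CI tweak, here is a back-of-the-envelope sketch of how the parametrized test matrix shrinks. The number of out_dtypes (assumed to be 2) and the handling of the CUDA-only input dtypes are assumptions, since neither appears in the diff above.

# Hypothetical sketch: count of test_cast_matmul parameter combinations
# before and after this commit. The out_dtypes size (2) is an assumption,
# and the CUDA-only dtypes ("int8", "float8_e5m2") that the file adds
# conditionally are ignored here.
block_grid = 3 * 2 * 3                 # BLOCK_K x BLOCK_M x BLOCK_N choices
before = block_grid * 2 * (4 * 4) * 2  # 2 shapes, 4x4 input-dtype pairs, 2 out dtypes
after = block_grid * 1 * (3 * 3) * 2   # 1 shape,  3x3 input-dtype pairs, 2 out dtypes
print(before, after)                   # 1152 324

Separately, the skip for CUDA devices with compute capability >= 10 at the (64, 64, 128) block shape is dropped, so those configurations now run; the HIP shared-memory skips are unchanged.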
