13 | 13 | import triton.language as tl
14 | 14 | from triton._internal_testing import is_hip_cdna3, is_cuda, is_hip
15 | 15 |
16 |    | -input_dtypes = ["bfloat16", "float16", "float32", "float64"]
   | 16 | +input_dtypes = ["bfloat16", "float16", "float32"]
17 | 17 | if is_cuda():
18 | 18 |     input_dtypes += ["int8", "float8_e5m2"]
19 | 19 |     cc = torch.cuda.get_device_capability(0)
@@ -80,13 +80,11 @@ def matmul_kernel(A, B, C, M, N, K, #
80 | 80 |                          for BLOCK_K in [16, 32, 64]  #
81 | 81 |                          for BLOCK_M in [16, 64]  #
82 | 82 |                          for BLOCK_N in [16, 64, 128]  #
83 |    | -                         for (M, K, N) in [(128, 128, 128), (768, 768, 1024)]  #
   | 83 | +                         for (M, K, N) in [(768, 768, 1024)]  #
84 | 84 |                          for w in input_dtypes
85 | 85 |                          for x in input_dtypes  #
86 | 86 |                          for o in out_dtypes])
87 | 87 | def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, BLOCK_N, w_dtype, x_dtype, out_dtype, device):
88 |    | -    if (is_cuda() and torch.cuda.get_device_capability(0)[0] >= 10) and (BLOCK_K, BLOCK_M, BLOCK_N) == (64, 64, 128):
89 |    | -        pytest.skip("skip as they run out of shared memory")
90 | 88 |     if is_hip() and (BLOCK_K, BLOCK_M, BLOCK_N) in ((64, 64, 128), (64, 16, 128)):
91 | 89 |         pytest.skip("skip as they run out of shared memory")
92 | 90 |     if x_dtype == w_dtype:
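For context, the parametrize list edited above is a nested comprehension over block sizes, problem shapes, and dtypes. The sketch below shows how the trimmed lists expand into the collected test matrix; it is not the file's exact decorator, and the two-entry out_dtypes list is only an assumed placeholder for the list defined elsewhere in the file.

    # A minimal sketch of how the trimmed parametrization expands into test cases.
    input_dtypes = ["bfloat16", "float16", "float32"]   # "float64" removed by this change
    out_dtypes = ["float16", "float32"]                 # placeholder, not the file's real list

    cases = [(M, K, N, BLOCK_K, BLOCK_M, BLOCK_N, w, x, o)
             for BLOCK_K in [16, 32, 64]
             for BLOCK_M in [16, 64]
             for BLOCK_N in [16, 64, 128]
             for (M, K, N) in [(768, 768, 1024)]        # (128, 128, 128) dropped by this change
             for w in input_dtypes
             for x in input_dtypes
             for o in out_dtypes]

    # 3 * 2 * 3 * 1 * 3 * 3 * 2 = 324 combinations with the placeholder out_dtypes;
    # on CUDA, input_dtypes also gains "int8" and "float8_e5m2", enlarging the sweep.
    print(len(cases))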
|