diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index a61e827ffffae..566856b18072e 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -64,6 +64,9 @@
 # Protects against includes accidentally setting the default dtype
 assert torch.get_default_dtype() is torch.float32
 
+input_dtypes = [torch.float32]
+if not torch.version.hip:
+    input_dtypes += [torch.float16, torch.bfloat16]
 
 @contextlib.contextmanager
 def blas_library_context(backend):
@@ -617,8 +620,7 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune)
 
 
     @onlyCUDA
-    @skipIfRocm
-    @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
+    @parametrize("input_dtype", input_dtypes)
     @parametrize("M", [1, 32, 64])
     @parametrize("N", [1, 32, 64])
     @parametrize("K", [1, 32, 64])
@@ -672,8 +674,7 @@ def create_inputs(B=None):
 
 
     @onlyCUDA
-    @skipIfRocm
-    @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
+    @parametrize("input_dtype", input_dtypes)
    @parametrize("M", [1, 32, 64])
    @parametrize("N", [1, 32, 64])
    @parametrize("K", [1, 32, 64])