12 | 12 | from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC |
13 | 13 | from torch.testing._internal.common_utils import \ |
14 | 14 | (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, |
15 | | - run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU, |
16 | | - suppress_warnings) |
| 15 | + run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, |
| 16 | + skipIfRocmVersionLessThan, IS_FBCODE, IS_REMOTE_GPU, suppress_warnings) |
17 | 17 | from torch.testing._internal.common_device_type import \ |
18 | 18 | (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric, |
19 | 19 | precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan, |
26 | 26 | all_types_and_complex, floating_and_complex_types_and) |
27 | 27 | from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve |
28 | 28 | from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse |
29 | | -from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED |
| 29 | +from test_sparse import HIPSPARSE_BF16_SUPPORTED, HIPSPARSE_FP16_SUPPORTED, \ |
| 30 | + SPARSE_FLOAT16_SUPPORTED, SPARSE_BFLOAT16_SUPPORTED, SPARSE_COMPLEX128_SUPPORTED |
30 | 31 | import operator |
31 | 32 |
32 | 33 | if TEST_SCIPY: |
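
For context, the flags newly imported from `test_sparse` fold the per-backend dtype conditions (CUDA compute capability plus the cuSPARSE generic API, and the hipSPARSE/ROCm version) into single names. Their definitions are not part of this hunk; the sketch below only illustrates how such flags could be derived, with the ROCm 6.3 threshold assumed from the `skipIfRocmVersionLessThan((6, 3))` decorators used later in the diff.

```python
# Illustrative sketch only -- the authoritative definitions live in test_sparse.py
# and may differ; the version thresholds here are assumptions, not taken from this commit.
import torch
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC


def _rocm_version():
    # torch.version.hip looks like "6.3.42131-..." on ROCm builds and is None otherwise.
    if not torch.version.hip:
        return (0, 0)
    major, minor = torch.version.hip.split(".")[:2]
    return (int(major), int(minor))


# Assumed: hipSPARSE fp16/bf16 SpMM support is tied to a minimum ROCm release.
HIPSPARSE_FP16_SUPPORTED = _rocm_version() >= (6, 3)
HIPSPARSE_BF16_SUPPORTED = _rocm_version() >= (6, 3)

# Backend-agnostic flags: true when either the CUDA path or the ROCm path supports the dtype.
SPARSE_FLOAT16_SUPPORTED = (SM53OrLater and TEST_CUSPARSE_GENERIC) or HIPSPARSE_FP16_SUPPORTED
SPARSE_BFLOAT16_SUPPORTED = (SM80OrLater and TEST_CUSPARSE_GENERIC) or HIPSPARSE_BF16_SUPPORTED

# SPARSE_COMPLEX128_SUPPORTED presumably folds together the old flags, e.g.
# CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED.
```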
@@ -1545,9 +1546,10 @@ def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device |
1545 | 1546 | run_test(c, a, a_batched, b, op_b, op_out, dtype=dtype, device=device) |
1546 | 1547 |
1547 | 1548 | @onlyCUDA |
1548 | | - @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported") |
| 1549 | + @skipIfRocmVersionLessThan((6, 3)) |
1549 | 1550 | @skipCUDAIfNoSparseGeneric |
1550 | | - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) |
| 1551 | + @dtypes(*floating_and_complex_types_and(*[torch.half] if HIPSPARSE_FP16_SUPPORTED else [], |
| 1552 | + *[torch.bfloat16] if HIPSPARSE_BF16_SUPPORTED else [])) |
1551 | 1553 | def test_bmm(self, device, dtype): |
1552 | 1554 | def run_test(a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None): |
1553 | 1555 | b = b.mH if (op_b and a.shape == b.shape) else b |
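
The reworked `@dtypes(...)` decorator above relies on Python argument splicing (`*[x] if flag else []`) so that `torch.half` and `torch.bfloat16` are only generated as test variants when the corresponding flag is set. A minimal standalone illustration of that pattern; the flag values and the `collect_dtypes` helper are made up for the example:

```python
import torch

HIPSPARSE_FP16_SUPPORTED = True   # assumed flag value for the example
HIPSPARSE_BF16_SUPPORTED = False  # assumed flag value for the example


def collect_dtypes(*extra):
    # Stand-in for floating_and_complex_types_and(...): appends the extra dtypes
    # to a fixed base list of floating and complex dtypes.
    base = (torch.float32, torch.float64, torch.complex64, torch.complex128)
    return base + extra


# `*[x] if flag else []` splices a dtype into the call only when the backend
# supports it, so unsupported variants are never generated in the first place.
dtypes = collect_dtypes(
    *[torch.half] if HIPSPARSE_FP16_SUPPORTED else [],
    *[torch.bfloat16] if HIPSPARSE_BF16_SUPPORTED else [],
)
assert dtypes == (torch.float32, torch.float64, torch.complex64,
                  torch.complex128, torch.half)
```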
@@ -1834,7 +1836,7 @@ def run_test(a, b, upper, transpose, unitriangular, op_out): |
1834 | 1836 | run_test(a, b, upper, unitriangular, transpose, op_out) |
1835 | 1837 |
1836 | 1838 | @skipCPUIfNoMklSparse |
1837 | | - @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported") |
| 1839 | + @skipIfRocmVersionLessThan((6, 3)) |
1838 | 1840 | @dtypes(torch.double) |
1839 | 1841 | def test_mm(self, device, dtype): |
1840 | 1842 | def test_shape(di, dj, dk, nnz0=None, nnz1=None): |
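
Both this hunk and the `test_bmm` hunk replace the blanket `unittest.skipIf(TEST_WITH_ROCM, ...)` skip with a minimum-version gate, so the tests now run on ROCm 6.3 and newer instead of being skipped on every ROCm build. A toy example contrasting the two gating styles (the test bodies are placeholders, not taken from this file):

```python
# Toy example contrasting a blanket ROCm skip with a version-gated skip.
import unittest

from torch.testing._internal.common_utils import (
    TEST_WITH_ROCM, TestCase, run_tests, skipIfRocmVersionLessThan,
)


class ExampleRocmGating(TestCase):
    # Old style: skipped on every ROCm build, regardless of how new it is.
    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
    def test_blanket_skip(self):
        self.assertTrue(True)

    # New style: runs on ROCm >= 6.3, skipped only on older ROCm releases.
    @skipIfRocmVersionLessThan((6, 3))
    def test_version_gated(self):
        self.assertTrue(True)


if __name__ == "__main__":
    run_tests()
```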
@@ -1954,8 +1956,8 @@ def test_shape(d1, d2, d3, nnz, transposed, index_dtype): |
1954 | 1956 |
1955 | 1957 | @dtypes(*floating_and_complex_types()) |
1956 | 1958 | @dtypesIfCUDA(*floating_and_complex_types_and( |
1957 | | - *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [], |
1958 | | - *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else [])) |
| 1959 | + *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [], |
| 1960 | + *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [])) |
1959 | 1961 | @precisionOverride({torch.bfloat16: 3.5e-2, torch.float16: 1e-2}) |
1960 | 1962 | def test_sparse_addmm(self, device, dtype): |
1961 | 1963 | def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None): |
@@ -1984,18 +1986,15 @@ def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None): |
1984 | 1986 | test_shape(7, 8, 9, 20, True, index_dtype, (1, 1)) |
1985 | 1987 |
1986 | 1988 | @skipCPUIfNoMklSparse |
| 1989 | + @skipIfRocmVersionLessThan((6, 3)) |
1987 | 1990 | @dtypes(*floating_and_complex_types()) |
1988 | 1991 | @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, |
1989 | 1992 | torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) |
1990 | 1993 | @dtypesIfCUDA(*floating_types_and(torch.complex64, |
1991 | | - *[torch.bfloat16] if SM80OrLater else [], |
1992 | | - *[torch.half] if SM53OrLater else [], |
1993 | | - *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else [])) |
| 1994 | + *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [], |
| 1995 | + *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [], |
| 1996 | + *[torch.complex128] if SPARSE_COMPLEX128_SUPPORTED else [])) |
1994 | 1997 | @sparse_compressed_nonblock_layouts() |
1995 | | - @skipCUDAIf( |
1996 | | - not _check_cusparse_spgemm_available(), |
1997 | | - "cuSparse Generic API SpGEMM is not available" |
1998 | | - ) |
1999 | 1998 | def test_addmm_all_sparse_csr(self, device, dtype, layout): |
2000 | 1999 | M = torch.randn(10, 25, device=device).to(dtype) |
2001 | 2000 | m1 = torch.randn(10, 50, device=device).to(dtype) |
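
The guard removed above used the generic `skipCUDAIf(condition, reason)` decorator from `common_device_type`; presumably the remaining version and dtype gates now cover the SpGEMM requirement. For reference, a minimal self-contained example of that decorator pattern, with a made-up condition and test:

```python
# Minimal example of the skipCUDAIf(condition, reason) pattern; the condition
# and the test body are made up for illustration and are not part of this commit.
import torch

from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, skipCUDAIf,
)
from torch.testing._internal.common_utils import TestCase, run_tests


class ExampleAvailabilitySkip(TestCase):
    # The CUDA variant of this test is skipped whenever the condition is true;
    # the CPU variant is unaffected.
    @skipCUDAIf(not torch.backends.cuda.is_built(), "CUDA support is not compiled in")
    def test_zeros(self, device):
        self.assertEqual(torch.zeros(1, device=device).item(), 0.0)


instantiate_device_type_tests(ExampleAvailabilitySkip, globals())

if __name__ == "__main__":
    run_tests()
```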
@@ -2066,16 +2065,12 @@ def maybe_transpose(cond, m): |
2066 | 2065 | @skipCPUIfNoMklSparse |
2067 | 2066 | @dtypes(*floating_and_complex_types()) |
2068 | 2067 | @dtypesIfCUDA(*floating_types_and(torch.complex64, |
2069 | | - *[torch.bfloat16] if SM80OrLater else [], |
2070 | | - *[torch.half] if SM53OrLater else [], |
2071 | | - *[torch.complex128] |
2072 | | - if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED |
2073 | | - else [])) |
| 2068 | + *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else [], |
| 2069 | + *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [], |
| 2070 | + *[torch.complex128] if SPARSE_COMPLEX128_SUPPORTED else [])) |
2074 | 2071 | @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, |
2075 | 2072 | torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) |
2076 | 2073 | def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k): |
2077 | | - if (TEST_WITH_ROCM and k != 0 and n != 0 and m != 0): |
2078 | | - self.skipTest("Skipped on ROCm") |
2079 | 2074 | M = torch.randn(n, m, device=device).to(dtype) |
2080 | 2075 | m1 = torch.randn(n, k, device=device).to(dtype) |
2081 | 2076 | m2 = torch.randn(k, m, device=device).to(dtype) |