Skip to content

Commit 9856e10

Browse files
dnikolaev-amdpragupta
authored andcommitted
[rocm7.0_internal_testing] fix enabling sparse tests fp16/bf16 for rocm7.0/7.1 (#2239)
Revamped version of #2108 PR to: - enable complex data types for sparse matmul on ROCm - fix sparse addmm/baddbmm on ROCm - fix sparse hipification for ROCm - fix/enable sparse tests on ROCm (~50 tests total for non-fp16/bf16): - enable fp16/bf16 sparse path for rocm7.0 - enable fp16/bf16 sparse tests for rocm7.0/7.1 ``` test_sparse_csr.py::TestSparseCSRCUDA::test_bmm_cuda_* test_sparse.py::TestSparseCUDA::test_sparse_matmul_cuda_* test_sparse_csr.py::TestSparseCSRCUDA::test_mm_cuda_float64 test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_all_sparse_csr_SparseCS* test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_sizes_all_sparse_csr_* test_sparse_csr.py::TestSparseCSRCUDA::test_sparse_addmm_cuda_float16 ``` (cherry picked from commit cc2a69c)
1 parent e57fde7 commit 9856e10

File tree

4 files changed

+30
-3
lines changed

4 files changed

+30
-3
lines changed

aten/src/ATen/native/sparse/cuda/SparseMatMul.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@
4646
#define IS_CUSPARSE11_AVAILABLE() 0
4747
#endif
4848

49+
#if defined(USE_ROCM) && (ROCM_VERSION >= 70000)
50+
#define HIPSPARSE_FP16_SUPPORT 1
51+
#else
52+
#define HIPSPARSE_FP16_SUPPORT 0
53+
#endif
54+
55+
#if defined(USE_ROCM) && (ROCM_VERSION >= 70100)
56+
#define HIPSPARSE_FP16_BF16_SUPPORT 1
57+
#else
58+
#define HIPSPARSE_FP16_BF16_SUPPORT 0
59+
#endif
60+
4961
#if IS_CUSPARSE11_AVAILABLE()
5062
#include <library_types.h>
5163
#endif

test/test_sparse.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def _op_supports_any_sparse(op):
6969
) or (not IS_WINDOWS and not TEST_WITH_ROCM)
7070

7171
HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")
72+
HIPSPARSE_FP16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.0")
73+
HIPSPARSE_BF16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.1")
74+
75+
SPARSE_COMPLEX128_SUPPORTED = CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
76+
SPARSE_FLOAT16_SUPPORTED = (SM53OrLater and torch.version.cuda) or (HIPSPARSE_FP16_SUPPORTED)
77+
SPARSE_BFLOAT16_SUPPORTED = (SM80OrLater and torch.version.cuda) or (HIPSPARSE_BF16_SUPPORTED)
7278

7379
def all_sparse_layouts(test_name='layout', include_strided=False):
7480
return parametrize(test_name, [

test/test_sparse_csr.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
all_types_and_complex, floating_and_complex_types_and)
2626
from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
2727
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
28-
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
28+
from test_sparse import HIPSPARSE_BF16_SUPPORTED, HIPSPARSE_FP16_SUPPORTED, \
29+
SPARSE_FLOAT16_SUPPORTED, SPARSE_BFLOAT16_SUPPORTED, SPARSE_COMPLEX128_SUPPORTED
2930
import operator
3031

3132
if TEST_SCIPY:
@@ -1940,8 +1941,8 @@ def test_shape(d1, d2, d3, nnz, transposed, index_dtype):
19401941

19411942
@dtypes(*floating_and_complex_types())
19421943
@dtypesIfCUDA(*floating_and_complex_types_and(
1943-
*[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
1944-
*[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
1944+
*[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
1945+
*[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else []))
19451946
@precisionOverride({torch.bfloat16: 3.5e-2, torch.float16: 1e-2})
19461947
def test_sparse_addmm(self, device, dtype):
19471948
def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):

torch/utils/hipify/cuda_to_hip_mappings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8588,6 +8588,14 @@
85888588
"CUSPARSE_STATUS_ZERO_PIVOT",
85898589
("HIPSPARSE_STATUS_ZERO_PIVOT", CONV_NUMERIC_LITERAL, API_SPECIAL),
85908590
),
8591+
(
8592+
"CUSPARSE_STATUS_NOT_SUPPORTED",
8593+
("HIPSPARSE_STATUS_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_SPECIAL),
8594+
),
8595+
(
8596+
"CUSPARSE_STATUS_INSUFFICIENT_RESOURCES",
8597+
("HIPSPARSE_STATUS_INSUFFICIENT_RESOURCES", CONV_NUMERIC_LITERAL, API_SPECIAL),
8598+
),
85918599
(
85928600
"CUSPARSE_OPERATION_TRANSPOSE",
85938601
("HIPSPARSE_OPERATION_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPECIAL),

0 commit comments

Comments
 (0)