Skip to content

Commit 158ddaa

Browse files
authored
fix cublas.h (#76137)
* fix cublas.h * fix
1 parent 5f173c7 commit 158ddaa

File tree

3 files changed

+82
-117
lines changed

3 files changed

+82
-117
lines changed

paddle/phi/backends/dynload/cublas.h

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -113,40 +113,19 @@ extern void *cublas_dso_handle;
113113
__macro(cublasZdotc_v2); \
114114
__macro(cublasCdotu_v2); \
115115
__macro(cublasZdotu_v2); \
116-
__macro(cublasDotEx);
117-
118-
CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
119-
120-
// APIs available after CUDA 8.0
121-
#if CUDA_VERSION >= 8000
122-
#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \
123-
__macro(cublasGemmEx); \
124-
__macro(cublasSgemmStridedBatched); \
125-
__macro(cublasDgemmStridedBatched); \
126-
__macro(cublasCgemmStridedBatched); \
127-
__macro(cublasZgemmStridedBatched); \
128-
__macro(cublasHgemmStridedBatched);
129-
130-
CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
131-
#endif
132-
133-
// APIs available after CUDA 9.0
134-
#if CUDA_VERSION >= 9000
135-
#define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) \
136-
__macro(cublasSetMathMode); \
137-
__macro(cublasGetMathMode);
138-
139-
CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
140-
#endif
141-
142-
// APIs available after CUDA 9.1
143-
#if CUDA_VERSION >= 9010
144-
#define CUBLAS_BLAS_ROUTINE_EACH_R4(__macro) \
145-
__macro(cublasGemmBatchedEx); \
116+
__macro(cublasDotEx); \
117+
__macro(cublasGemmEx); \
118+
__macro(cublasSgemmStridedBatched); \
119+
__macro(cublasDgemmStridedBatched); \
120+
__macro(cublasCgemmStridedBatched); \
121+
__macro(cublasZgemmStridedBatched); \
122+
__macro(cublasHgemmStridedBatched); \
123+
__macro(cublasSetMathMode); \
124+
__macro(cublasGetMathMode); \
125+
__macro(cublasGemmBatchedEx); \
146126
__macro(cublasGemmStridedBatchedEx);
147127

148-
CUBLAS_BLAS_ROUTINE_EACH_R4(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
149-
#endif
128+
CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
150129

151130
#if CUDA_VERSION >= 12030 && defined(__linux__)
152131
#define CUBLAS_BLAS_ROUTINE_EACH_R5(__macro) \

paddle/phi/backends/dynload/cusolver.h

Lines changed: 71 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -40,87 +40,76 @@ extern void *cusolver_dso_handle;
4040
}; \
4141
extern DynLoad__##__name __name
4242

43-
#define CUSOLVER_ROUTINE_EACH(__macro) \
44-
__macro(cusolverDnCreate); \
45-
__macro(cusolverDnDestroy); \
46-
__macro(cusolverDnSetStream); \
47-
__macro(cusolverDnSpotrf_bufferSize); \
48-
__macro(cusolverDnDpotrf_bufferSize); \
49-
__macro(cusolverDnXpotrf_bufferSize); \
50-
__macro(cusolverDnSpotrf); \
51-
__macro(cusolverDnDpotrf); \
52-
__macro(cusolverDnXpotrf); \
53-
__macro(cusolverDnSpotrs); \
54-
__macro(cusolverDnDpotrs); \
55-
__macro(cusolverDnCpotrs); \
56-
__macro(cusolverDnZpotrs); \
57-
__macro(cusolverDnSsyevd_bufferSize); \
58-
__macro(cusolverDnDsyevd_bufferSize); \
59-
__macro(cusolverDnCheevd_bufferSize); \
60-
__macro(cusolverDnZheevd_bufferSize); \
61-
__macro(cusolverDnSsyevd); \
62-
__macro(cusolverDnDsyevd); \
63-
__macro(cusolverDnCheevd); \
64-
__macro(cusolverDnZheevd);
65-
66-
CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
67-
68-
#if CUDA_VERSION >= 9020
69-
#define CUSOLVER_ROUTINE_EACH_R1(__macro) \
70-
__macro(cusolverDnSpotrfBatched); \
71-
__macro(cusolverDnDpotrfBatched); \
72-
__macro(cusolverDnSpotrsBatched); \
73-
__macro(cusolverDnDpotrsBatched); \
74-
__macro(cusolverDnSgetrf_bufferSize); \
75-
__macro(cusolverDnDgetrf_bufferSize); \
76-
__macro(cusolverDnCgetrf_bufferSize); \
77-
__macro(cusolverDnZgetrf_bufferSize); \
78-
__macro(cusolverDnSgeqrf_bufferSize); \
79-
__macro(cusolverDnDgeqrf_bufferSize); \
80-
__macro(cusolverDnCgeqrf_bufferSize); \
81-
__macro(cusolverDnZgeqrf_bufferSize); \
82-
__macro(cusolverDnXgeqrf_bufferSize); \
83-
__macro(cusolverDnSorgqr_bufferSize); \
84-
__macro(cusolverDnDorgqr_bufferSize); \
85-
__macro(cusolverDnSormqr_bufferSize); \
86-
__macro(cusolverDnDormqr_bufferSize); \
87-
__macro(cusolverDnCungqr_bufferSize); \
88-
__macro(cusolverDnZungqr_bufferSize); \
89-
__macro(cusolverDnDestroyGesvdjInfo); \
90-
__macro(cusolverDnCreateGesvdjInfo); \
91-
__macro(cusolverDnSgesvdj_bufferSize); \
92-
__macro(cusolverDnDgesvdj_bufferSize); \
93-
__macro(cusolverDnCgesvdj_bufferSize); \
94-
__macro(cusolverDnZgesvdj_bufferSize); \
95-
__macro(cusolverDnSgesvdj); \
96-
__macro(cusolverDnDgesvdj); \
97-
__macro(cusolverDnCgesvdj); \
98-
__macro(cusolverDnZgesvdj); \
99-
__macro(cusolverDnSgetrf); \
100-
__macro(cusolverDnSgetrs); \
101-
__macro(cusolverDnDgetrs); \
102-
__macro(cusolverDnCgetrs); \
103-
__macro(cusolverDnZgetrs); \
104-
__macro(cusolverDnDgetrf); \
105-
__macro(cusolverDnCgetrf); \
106-
__macro(cusolverDnZgetrf); \
107-
__macro(cusolverDnSgeqrf); \
108-
__macro(cusolverDnDgeqrf); \
109-
__macro(cusolverDnCgeqrf); \
110-
__macro(cusolverDnZgeqrf); \
111-
__macro(cusolverDnXgeqrf); \
112-
__macro(cusolverDnSorgqr); \
113-
__macro(cusolverDnDorgqr); \
114-
__macro(cusolverDnSormqr); \
115-
__macro(cusolverDnDormqr); \
116-
__macro(cusolverDnCungqr); \
117-
__macro(cusolverDnZungqr);
118-
119-
CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
120-
#endif
121-
122-
#if CUDA_VERSION >= 9020
123-
#define CUSOLVER_ROUTINE_EACH_R2(__macro) \
43+
#define CUSOLVER_ROUTINE_EACH(__macro) \
44+
__macro(cusolverDnCreate); \
45+
__macro(cusolverDnDestroy); \
46+
__macro(cusolverDnSetStream); \
47+
__macro(cusolverDnSpotrf_bufferSize); \
48+
__macro(cusolverDnDpotrf_bufferSize); \
49+
__macro(cusolverDnXpotrf_bufferSize); \
50+
__macro(cusolverDnSpotrf); \
51+
__macro(cusolverDnDpotrf); \
52+
__macro(cusolverDnXpotrf); \
53+
__macro(cusolverDnSpotrs); \
54+
__macro(cusolverDnDpotrs); \
55+
__macro(cusolverDnCpotrs); \
56+
__macro(cusolverDnZpotrs); \
57+
__macro(cusolverDnSsyevd_bufferSize); \
58+
__macro(cusolverDnDsyevd_bufferSize); \
59+
__macro(cusolverDnCheevd_bufferSize); \
60+
__macro(cusolverDnZheevd_bufferSize); \
61+
__macro(cusolverDnSsyevd); \
62+
__macro(cusolverDnDsyevd); \
63+
__macro(cusolverDnCheevd); \
64+
__macro(cusolverDnZheevd); \
65+
__macro(cusolverDnSpotrfBatched); \
66+
__macro(cusolverDnDpotrfBatched); \
67+
__macro(cusolverDnSpotrsBatched); \
68+
__macro(cusolverDnDpotrsBatched); \
69+
__macro(cusolverDnSgetrf_bufferSize); \
70+
__macro(cusolverDnDgetrf_bufferSize); \
71+
__macro(cusolverDnCgetrf_bufferSize); \
72+
__macro(cusolverDnZgetrf_bufferSize); \
73+
__macro(cusolverDnSgeqrf_bufferSize); \
74+
__macro(cusolverDnDgeqrf_bufferSize); \
75+
__macro(cusolverDnCgeqrf_bufferSize); \
76+
__macro(cusolverDnZgeqrf_bufferSize); \
77+
__macro(cusolverDnXgeqrf_bufferSize); \
78+
__macro(cusolverDnSorgqr_bufferSize); \
79+
__macro(cusolverDnDorgqr_bufferSize); \
80+
__macro(cusolverDnSormqr_bufferSize); \
81+
__macro(cusolverDnDormqr_bufferSize); \
82+
__macro(cusolverDnCungqr_bufferSize); \
83+
__macro(cusolverDnZungqr_bufferSize); \
84+
__macro(cusolverDnDestroyGesvdjInfo); \
85+
__macro(cusolverDnCreateGesvdjInfo); \
86+
__macro(cusolverDnSgesvdj_bufferSize); \
87+
__macro(cusolverDnDgesvdj_bufferSize); \
88+
__macro(cusolverDnCgesvdj_bufferSize); \
89+
__macro(cusolverDnZgesvdj_bufferSize); \
90+
__macro(cusolverDnSgesvdj); \
91+
__macro(cusolverDnDgesvdj); \
92+
__macro(cusolverDnCgesvdj); \
93+
__macro(cusolverDnZgesvdj); \
94+
__macro(cusolverDnSgetrf); \
95+
__macro(cusolverDnSgetrs); \
96+
__macro(cusolverDnDgetrs); \
97+
__macro(cusolverDnCgetrs); \
98+
__macro(cusolverDnZgetrs); \
99+
__macro(cusolverDnDgetrf); \
100+
__macro(cusolverDnCgetrf); \
101+
__macro(cusolverDnZgetrf); \
102+
__macro(cusolverDnSgeqrf); \
103+
__macro(cusolverDnDgeqrf); \
104+
__macro(cusolverDnCgeqrf); \
105+
__macro(cusolverDnZgeqrf); \
106+
__macro(cusolverDnXgeqrf); \
107+
__macro(cusolverDnSorgqr); \
108+
__macro(cusolverDnDorgqr); \
109+
__macro(cusolverDnSormqr); \
110+
__macro(cusolverDnDormqr); \
111+
__macro(cusolverDnCungqr); \
112+
__macro(cusolverDnZungqr); \
124113
__macro(cusolverDnCreateSyevjInfo); \
125114
__macro(cusolverDnCreateParams); \
126115
__macro(cusolverDnDestroyParams); \
@@ -143,8 +132,7 @@ CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
143132
__macro(cusolverDnCheevjBatched); \
144133
__macro(cusolverDnZheevjBatched);
145134

146-
CUSOLVER_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
147-
#endif
135+
CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
148136

149137
#undef DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP
150138
} // namespace dynload

paddle/phi/backends/dynload/cusparse.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ extern void *cusparse_dso_handle;
4141
extern DynLoad__##__name __name
4242

4343
#if defined(PADDLE_WITH_CUDA)
44-
#if CUDA_VERSION >= 11000
4544
#define CUSPARSE_ROUTINE_EACH(__macro) \
4645
__macro(cusparseCreate); \
4746
__macro(cusparseSetStream); \
@@ -71,7 +70,6 @@ extern void *cusparse_dso_handle;
7170
__macro(cusparseSpGEMM_destroyDescr);
7271

7372
CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
74-
#endif
7573

7674
#if CUDA_VERSION >= 11030
7775
#define CUSPARSE_ROUTINE_EACH_R2(__macro) \

0 commit comments

Comments
 (0)