@@ -31,23 +31,24 @@ template <>
 struct CUBlas<float> {
   template <typename... ARGS>
   static void GEMM(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemm(args...));
   }
 
   template <typename... ARGS>
   static void AXPY(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasSaxpy(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSaxpy(args...));
   }
 
   template <typename... ARGS>
   static void GEMV(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasSgemv(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemv(args...));
   }
 
   template <typename... ARGS>
   static void GEMM_STRIDED_BATCH(ARGS... args) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::cublasSgemmStridedBatched(args...));
 #else
     PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
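For reference, the macro swapped in above is expected to consume the `cublasStatus_t` returned by each `platform::dynload` wrapper and raise on anything other than success, instead of treating the status as a generic boolean the way `PADDLE_ENFORCE` did. A minimal sketch of such a status-checking macro, assuming a simplified error path (`CHECK_CUBLAS_SUCCESS` is a hypothetical stand-in, not Paddle's actual definition, which also records file/line context):

```cpp
#include <cublas_v2.h>

#include <stdexcept>
#include <string>

// Hypothetical, simplified stand-in for PADDLE_ENFORCE_CUDA_SUCCESS:
// evaluate the cuBLAS call once and throw if it did not return
// CUBLAS_STATUS_SUCCESS.
#define CHECK_CUBLAS_SUCCESS(call)                                  \
  do {                                                              \
    cublasStatus_t status_ = (call);                                \
    if (status_ != CUBLAS_STATUS_SUCCESS) {                         \
      throw std::runtime_error("cuBLAS call failed with status " +  \
                               std::to_string(status_));            \
    }                                                               \
  } while (0)
```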
@@ -69,7 +70,7 @@ struct CUBlas<float> {
     VLOG(5) << "use_tensor_op_math: "
             << (dev_ctx->tensor_core_available() ? "True" : "False");
     dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-      PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemmEx(
           handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
           beta, C, Ctype, ldc));
     });
@@ -83,23 +84,24 @@ template <>
 struct CUBlas<double> {
   template <typename... ARGS>
   static void GEMM(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemm(args...));
   }
 
   template <typename... ARGS>
   static void AXPY(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasDaxpy(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDaxpy(args...));
   }
 
   template <typename... ARGS>
   static void GEMV(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasDgemv(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemv(args...));
   }
 
   template <typename... ARGS>
   static void GEMM_STRIDED_BATCH(ARGS... args) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::cublasDgemmStridedBatched(args...));
 #else
     PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
@@ -120,7 +122,7 @@ struct CUBlas<platform::float16> {
                    const float16 *alpha, const float16 *A, int lda,
                    const float16 *B, int ldb, const float16 *beta, float16 *C,
                    int ldc) {
-    PADDLE_ENFORCE(
+    PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
                                        reinterpret_cast<const __half *>(alpha),
                                        reinterpret_cast<const __half *>(A), lda,
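A note on the `reinterpret_cast` pattern in this hunk: it relies on `paddle::platform::float16` being layout-compatible with CUDA's `__half`. An illustrative compile-time guard for that assumption could look like the following (the `static_assert` is not part of the patch, and the include path is assumed from the Paddle tree of this era):

```cpp
#include <cuda_fp16.h>

#include "paddle/fluid/platform/float16.h"

// Illustrative guard (not in the patch): the casts from float16* to
// __half* are only valid if the two types share size and bit layout.
static_assert(sizeof(paddle::platform::float16) == sizeof(__half),
              "platform::float16 must be bit-compatible with __half");
```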
@@ -140,7 +142,7 @@ struct CUBlas<platform::float16> {
                                  long long int strideC,  // NOLINT
                                  int batchCount) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasHgemmStridedBatched(
         handle, transa, transb, m, n, k,
         reinterpret_cast<const __half *>(alpha),
         reinterpret_cast<const __half *>(A), lda, strideA,
@@ -174,7 +176,7 @@ struct CUBlas<platform::float16> {
 #endif  // CUDA_VERSION >= 9000
 
     dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-      PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx(
           handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
           beta, C, Ctype, ldc, computeType, algo));
     });
@@ -356,7 +358,7 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
           << (use_tensor_op_math ? "True" : "False");
 
   context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-    PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
         handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
         strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
         strideC, batchCount, CUDA_R_32F, algo));
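For readers unfamiliar with the call at the end of this hunk: `cublasGemmStridedBatchedEx` performs one GEMM per batch entry, and the kernel passes `B` before `A` so that cuBLAS's column-major math produces the row-major product `C = A * B` directly. A standalone sketch of the same pattern (plain CUDA/cuBLAS, not Paddle code; status checks are omitted for brevity, and on CUDA 11+ the compute-type argument would be `CUBLAS_COMPUTE_32F` rather than `CUDA_R_32F`):

```cpp
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <vector>

int main() {
  const int M = 2, N = 3, K = 4, batchCount = 5;
  std::vector<float> hA(batchCount * M * K, 1.0f);  // row-major M x K
  std::vector<float> hB(batchCount * K * N, 1.0f);  // row-major K x N
  std::vector<float> hC(batchCount * M * N, 0.0f);  // row-major M x N

  float *A, *B, *C;
  cudaMalloc(&A, hA.size() * sizeof(float));
  cudaMalloc(&B, hB.size() * sizeof(float));
  cudaMalloc(&C, hC.size() * sizeof(float));
  cudaMemcpy(A, hA.data(), hA.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(B, hB.data(), hB.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);

  const float alpha = 1.0f, beta = 0.0f;
  // B-before-A argument order, mirroring the Paddle kernel: viewed
  // column-major, the row-major buffers are B^T (ld N) and A^T (ld K),
  // and B^T * A^T lands in memory exactly where row-major C lives.
  cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
      B, CUDA_R_32F, N, static_cast<long long>(K) * N,
      A, CUDA_R_32F, K, static_cast<long long>(M) * K,
      &beta,
      C, CUDA_R_32F, N, static_cast<long long>(M) * N,
      batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);

  cudaMemcpy(hC.data(), C, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);
  cublasDestroy(handle);
  cudaFree(A);
  cudaFree(B);
  cudaFree(C);
  return 0;  // every entry of hC should equal K (= 4.0f)
}
```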