Commit 3b5f354

silingtong123 authored and liupluswei committed
Modify PADDLE_ENFORCE to PADDLE_ENFORCE_CUDA_SUCCESS (#19247)
* add PADDLE_ENFORCE_CUDA_SUCCESS, test=develop (#19211)
* test=develop, Modify PADDLE_ENFORCE to PADDLE_ENFORCE_CUDA_SUCCESS
1 parent 305bd25 commit 3b5f354

File tree: 3 files changed, +105 -13 lines


paddle/fluid/operators/math/blas_impl.cu.h

Lines changed: 15 additions & 13 deletions
@@ -31,23 +31,24 @@ template <>
 struct CUBlas<float> {
   template <typename... ARGS>
   static void GEMM(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemm(args...));
   }
 
   template <typename... ARGS>
   static void AXPY(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasSaxpy(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSaxpy(args...));
   }
 
   template <typename... ARGS>
   static void GEMV(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasSgemv(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemv(args...));
   }
 
   template <typename... ARGS>
   static void GEMM_STRIDED_BATCH(ARGS... args) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::cublasSgemmStridedBatched(args...));
 #else
     PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
@@ -69,7 +70,7 @@ struct CUBlas<float> {
     VLOG(5) << "use_tensor_op_math: "
             << (dev_ctx->tensor_core_available() ? "True" : "False");
     dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-      PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasSgemmEx(
           handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
           beta, C, Ctype, ldc));
     });
@@ -83,23 +84,24 @@ template <>
 struct CUBlas<double> {
   template <typename... ARGS>
   static void GEMM(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemm(args...));
   }
 
   template <typename... ARGS>
   static void AXPY(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasDaxpy(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDaxpy(args...));
   }
 
   template <typename... ARGS>
   static void GEMV(ARGS... args) {
-    PADDLE_ENFORCE(platform::dynload::cublasDgemv(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDgemv(args...));
   }
 
   template <typename... ARGS>
   static void GEMM_STRIDED_BATCH(ARGS... args) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(args...));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::cublasDgemmStridedBatched(args...));
 #else
     PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
@@ -120,7 +122,7 @@ struct CUBlas<platform::float16> {
                    const float16 *alpha, const float16 *A, int lda,
                    const float16 *B, int ldb, const float16 *beta, float16 *C,
                    int ldc) {
-    PADDLE_ENFORCE(
+    PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
                                        reinterpret_cast<const __half *>(alpha),
                                        reinterpret_cast<const __half *>(A), lda,
@@ -140,7 +142,7 @@ struct CUBlas<platform::float16> {
                                   long long int strideC,  // NOLINT
                                   int batchCount) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasHgemmStridedBatched(
         handle, transa, transb, m, n, k,
         reinterpret_cast<const __half *>(alpha),
         reinterpret_cast<const __half *>(A), lda, strideA,
@@ -174,7 +176,7 @@ struct CUBlas<platform::float16> {
 #endif  // CUDA_VERSION >= 9000
 
     dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-      PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx(
           handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
           beta, C, Ctype, ldc, computeType, algo));
     });
@@ -356,7 +358,7 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
           << (use_tensor_op_math ? "True" : "False");
 
   context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
-    PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
         handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
         strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
         strideC, batchCount, CUDA_R_32F, algo));
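
Every call site above follows the same pattern: a dynloaded cuBLAS routine returns a typed cublasStatus_t, and the new macro compares it against that type's success value rather than coercing it to a boolean condition. A minimal standalone sketch of the pattern, outside Paddle; the CHECK_CUBLAS helper here is hypothetical, a stand-in for PADDLE_ENFORCE_CUDA_SUCCESS:

// Standalone sketch (not Paddle code): check a cuBLAS status against its
// typed success value, the same idea the diff above applies per call site.
#include <cublas_v2.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_CUBLAS(call)                                    \
  do {                                                        \
    cublasStatus_t s_ = (call);                               \
    if (s_ != CUBLAS_STATUS_SUCCESS) {                        \
      std::fprintf(stderr, "cuBLAS error %d at %s:%d\n",      \
                   static_cast<int>(s_), __FILE__, __LINE__); \
      std::exit(EXIT_FAILURE);                                \
    }                                                         \
  } while (0)

int main() {
  cublasHandle_t handle;
  CHECK_CUBLAS(cublasCreate(&handle));  // returns cublasStatus_t, not bool
  CHECK_CUBLAS(cublasDestroy(handle));
  return 0;
}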

paddle/fluid/platform/enforce.h

Lines changed: 47 additions & 0 deletions
@@ -236,6 +236,31 @@ inline void throw_on_error(ncclResult_t stat, const std::string& msg) {
 #endif  // __APPLE__ and windows
 #endif  // PADDLE_WITH_CUDA
 
+#ifdef PADDLE_WITH_CUDA
+namespace details {
+
+template <typename T>
+struct CudaStatusType {};
+
+#define DEFINE_CUDA_STATUS_TYPE(type, success_value) \
+  template <>                                        \
+  struct CudaStatusType<type> {                      \
+    using Type = type;                               \
+    static constexpr Type kSuccess = success_value;  \
+  }
+
+DEFINE_CUDA_STATUS_TYPE(cudaError_t, cudaSuccess);
+DEFINE_CUDA_STATUS_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS);
+DEFINE_CUDA_STATUS_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS);
+DEFINE_CUDA_STATUS_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS);
+
+#if !defined(__APPLE__) && !defined(_WIN32)
+DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
+#endif
+
+}  // namespace details
+#endif
+
 #define PADDLE_THROW(...)                    \
   do {                                       \
     throw ::paddle::platform::EnforceNotMet( \
@@ -256,6 +281,28 @@ inline void throw_on_error(ncclResult_t stat, const std::string& msg) {
   }                                          \
   } while (0)
 
+#ifdef PADDLE_WITH_CUDA
+#define PADDLE_ENFORCE_CUDA_SUCCESS(COND, ...)                             \
+  do {                                                                     \
+    auto __cond__ = (COND);                                                \
+    using __CUDA_STATUS_TYPE__ = decltype(__cond__);                       \
+    constexpr auto __success_type__ =                                      \
+        ::paddle::platform::details::CudaStatusType<                       \
+            __CUDA_STATUS_TYPE__>::kSuccess;                               \
+    if (UNLIKELY(__cond__ != __success_type__)) {                          \
+      try {                                                                \
+        ::paddle::platform::throw_on_error(                                \
+            __cond__, ::paddle::string::Sprintf(__VA_ARGS__));             \
+      } catch (...) {                                                      \
+        throw ::paddle::platform::EnforceNotMet(std::current_exception(),  \
+                                                __FILE__, __LINE__);       \
+      }                                                                    \
+    }                                                                      \
+  } while (0)
+
+#undef DEFINE_CUDA_STATUS_TYPE
+#endif
+
 #define PADDLE_THROW_EOF()                                                   \
   do {                                                                       \
     throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \
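
The macro dispatches on decltype(__cond__): each CUDA status type gets a CudaStatusType specialization carrying its success constant, so a single macro handles cudaError_t, curandStatus_t, cudnnStatus_t, cublasStatus_t, and ncclResult_t uniformly. A minimal standalone sketch of that trait dispatch; StatusTrait and throw_if_failed are illustrative names, not Paddle's API:

// Standalone sketch of the trait dispatch used by the macro above:
// a specialized trait maps each status type to its success value,
// resolved at compile time from the expression's type.
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <stdexcept>
#include <string>

template <typename T>
struct StatusTrait {};

#define DEFINE_STATUS_TRAIT(type, success_value)    \
  template <>                                       \
  struct StatusTrait<type> {                        \
    static constexpr type kSuccess = success_value; \
  }

DEFINE_STATUS_TRAIT(cudaError_t, cudaSuccess);
DEFINE_STATUS_TRAIT(cublasStatus_t, CUBLAS_STATUS_SUCCESS);

// One checker covers every status type with a trait specialization.
template <typename T>
void throw_if_failed(T status, const std::string& msg) {
  if (status != StatusTrait<T>::kSuccess) {
    throw std::runtime_error(
        msg + " (status " + std::to_string(static_cast<int>(status)) + ")");
  }
}

int main() {
  throw_if_failed(cudaSetDevice(0), "cudaSetDevice failed");
  return 0;
}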

paddle/fluid/platform/enforce_test.cc

Lines changed: 43 additions & 0 deletions
@@ -253,3 +253,46 @@ TEST(EOF_EXCEPTION, THROW_EOF) {
   }
   EXPECT_TRUE(caught_eof);
 }
+
+#ifdef PADDLE_WITH_CUDA
+template <typename T>
+bool CheckCudaStatusSuccess(T value, const std::string& msg = "success") {
+  PADDLE_ENFORCE_CUDA_SUCCESS(value, msg);
+  return true;
+}
+
+template <typename T>
+bool CheckCudaStatusFailure(
+    T value, const std::string& msg = "self-defined cuda status failed") {
+  try {
+    PADDLE_ENFORCE_CUDA_SUCCESS(value, msg);
+    return false;
+  } catch (paddle::platform::EnforceNotMet& error) {
+    std::string ex_msg = error.what();
+    return ex_msg.find(msg) != std::string::npos;
+  }
+}
+
+TEST(enforce, cuda_success) {
+  EXPECT_TRUE(CheckCudaStatusSuccess(cudaSuccess));
+  EXPECT_TRUE(CheckCudaStatusFailure(cudaErrorInvalidValue));
+  EXPECT_TRUE(CheckCudaStatusFailure(cudaErrorMemoryAllocation));
+
+  EXPECT_TRUE(CheckCudaStatusSuccess(CURAND_STATUS_SUCCESS));
+  EXPECT_TRUE(CheckCudaStatusFailure(CURAND_STATUS_VERSION_MISMATCH));
+  EXPECT_TRUE(CheckCudaStatusFailure(CURAND_STATUS_NOT_INITIALIZED));
+
+  EXPECT_TRUE(CheckCudaStatusSuccess(CUDNN_STATUS_SUCCESS));
+  EXPECT_TRUE(CheckCudaStatusFailure(CUDNN_STATUS_NOT_INITIALIZED));
+  EXPECT_TRUE(CheckCudaStatusFailure(CUDNN_STATUS_ALLOC_FAILED));
+
+  EXPECT_TRUE(CheckCudaStatusSuccess(CUBLAS_STATUS_SUCCESS));
+  EXPECT_TRUE(CheckCudaStatusFailure(CUBLAS_STATUS_NOT_INITIALIZED));
+  EXPECT_TRUE(CheckCudaStatusFailure(CUBLAS_STATUS_INVALID_VALUE));
+#if !defined(__APPLE__) && !defined(_WIN32)
+  EXPECT_TRUE(CheckCudaStatusSuccess(ncclSuccess));
+  EXPECT_TRUE(CheckCudaStatusFailure(ncclUnhandledCudaError));
+  EXPECT_TRUE(CheckCudaStatusFailure(ncclSystemError));
+#endif
+}
+#endif
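
The tests pass the message as a single argument, but the macro forwards its trailing arguments through ::paddle::string::Sprintf, so call sites can use printf-style formatting. A hedged sketch of such a call site, assuming paddle/fluid/platform/enforce.h is included; CopyToHost and its parameters are hypothetical, not lines from this commit:

// Hypothetical call site (not part of this diff): the arguments after the
// status expression are formatted by ::paddle::string::Sprintf.
void CopyToHost(void* dst, const void* src, size_t num_bytes) {
  PADDLE_ENFORCE_CUDA_SUCCESS(
      cudaMemcpy(dst, src, num_bytes, cudaMemcpyDeviceToHost),
      "cudaMemcpy of %d bytes from device to host failed", num_bytes);
}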
