Skip to content

Commit 617e790

Browse files
authored
fix cuda 7.5 compile error (#9885)
1 parent 859fedf commit 617e790

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

paddle/fluid/operators/math/math_function.cu

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,14 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
288288
// TODO(kexinzhao): add processing code for compute capability < 53 case
289289
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
290290
"cublas Hgemm requires GPU compute capability >= 53");
291+
292+
#if CUDA_VERSION >= 8000
291293
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
292294
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
293295
strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
296+
#else
297+
PADDLE_ENFORCE(false, "HgemmStridedBatched is not supported on cuda <= 7.5");
298+
#endif
294299
}
295300

296301
template <>
@@ -310,9 +315,13 @@ void batched_gemm<platform::CUDADeviceContext, float>(
310315
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
311316
const int strideC = M * N;
312317

318+
#if CUDA_VERSION >= 8000
313319
PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
314320
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
315321
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
322+
#else
323+
PADDLE_ENFORCE(false, "SgemmStridedBatched is not supported on cuda <= 7.5");
324+
#endif
316325
}
317326

318327
template <>
@@ -332,9 +341,13 @@ void batched_gemm<platform::CUDADeviceContext, double>(
332341
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
333342
const int strideC = M * N;
334343

344+
#if CUDA_VERSION >= 8000
335345
PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
336346
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
337347
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
348+
#else
349+
PADDLE_ENFORCE(false, "DgemmStridedBatched is not supported on cuda <= 7.5");
350+
#endif
338351
}
339352

340353
template <>

0 commit comments

Comments
 (0)