@@ -268,6 +268,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float16 alpha, const float16* A, const float16* B, const float16 beta,
    float16* C, const int batchCount, const int strideA, const int strideB) {
+#if CUDA_VERSION >= 8000
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
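Note on the "fortran order" comment above: cuBLAS expects column-major storage, so these wrappers compute C^T = B^T * A^T by swapping the operand order and the M/N dimensions, which yields row-major C with no explicit transpose. A minimal sketch of the same trick outside Paddle, using plain cublasSgemm (the row_major_gemm helper name is illustrative, not from this patch):

// Sketch only: row-major C = alpha*A*B + beta*C on top of column-major
// cuBLAS, by computing C^T = B^T * A^T.
#include <cublas_v2.h>

cublasStatus_t row_major_gemm(cublasHandle_t handle, int M, int N, int K,
                              const float* A, const float* B, float* C) {
  const float alpha = 1.0f, beta = 0.0f;
  // Pass B first and swap M/N: cuBLAS writes C^T in column-major layout,
  // which is bit-identical to C in row-major layout.
  return cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
                     B, /*ldb=*/N, A, /*lda=*/K, &beta, C, /*ldc=*/N);
}

This is exactly why the calls below pass N, M, K and put B (with ldb) before A (with lda).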
@@ -289,7 +290,6 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
  PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
                    "cublas Hgemm requires GPU compute capability >= 53");

-#if CUDA_VERSION >= 8000
  PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
      strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
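The capability check above uses Paddle's PADDLE_ENFORCE_GE and GetComputeCapability helpers; native FP16 GEMM (cublasHgemm and friends) needs SM 5.3 or newer. A rough standalone equivalent with the raw CUDA runtime API, as a sketch (the function name and error handling here are assumptions, not Paddle code):

// Sketch: runtime check that the device can run cublasHgemm*.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

void require_fp16_gemm_support(int device) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);  // error checking omitted in sketch
  int compute_capability = prop.major * 10 + prop.minor;
  if (compute_capability < 53) {  // FP16 arithmetic needs SM 5.3+
    std::fprintf(stderr,
                 "cublas Hgemm requires GPU compute capability >= 53\n");
    std::exit(EXIT_FAILURE);
  }
}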
@@ -304,6 +304,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C, const int batchCount, const int strideA, const int strideB) {
+#if CUDA_VERSION >= 8000
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
@@ -315,7 +316,6 @@ void batched_gemm<platform::CUDADeviceContext, float>(
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int strideC = M * N;

-#if CUDA_VERSION >= 8000
  PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
      strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
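For reference, a minimal self-contained use of cublasSgemmStridedBatched with contiguous row-major batches (the batched_example name and shapes are assumptions for illustration): each stride is the element offset between consecutive matrices in a batch, matching the strideA/strideB parameters and the strideC = M * N computed above.

// Sketch: batchCount independent row-major MxK * KxN -> MxN products.
#include <cublas_v2.h>

void batched_example(cublasHandle_t handle, const float* A, const float* B,
                     float* C, int M, int N, int K, int batchCount) {
  const float alpha = 1.0f, beta = 0.0f;
  // Matrix i of a batch starts stride elements after matrix i-1.
  long long strideA = static_cast<long long>(M) * K;
  long long strideB = static_cast<long long>(K) * N;
  long long strideC = static_cast<long long>(M) * N;
  // Same row-major trick as above: B first, M/N swapped.
  cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
                            B, /*ldb=*/N, strideB, A, /*lda=*/K, strideA,
                            &beta, C, /*ldc=*/N, strideC, batchCount);
}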
@@ -330,6 +330,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C, const int batchCount, const int strideA, const int strideB) {
+#if CUDA_VERSION >= 8000
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
@@ -341,7 +342,6 @@ void batched_gemm<platform::CUDADeviceContext, double>(
      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const int strideC = M * N;

-#if CUDA_VERSION >= 8000
  PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
      strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
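Net effect of these hunks: #if CUDA_VERSION >= 8000 now opens right after each function's opening brace instead of immediately before the cublas*gemmStridedBatched call, since those entry points first shipped in CUDA 8.0 and guarding only the call would leave lda/ldb/ldc and the strides as unused locals on older toolkits. The pattern, sketched in isolation (the #else fallback is an assumption about what follows these hunks, not shown here; PADDLE_ENFORCE comes from paddle/platform/enforce.h):

#include <cuda.h>  // defines CUDA_VERSION

void batched_gemm_guarded(/* same parameter list as the hunks above */) {
#if CUDA_VERSION >= 8000
  // ...compute lda/ldb/ldc and the strides, then call
  // cublas{H,S,D}gemmStridedBatched as shown above...
#else
  PADDLE_ENFORCE(false, "gemmStridedBatched is not supported on CUDA < 8.0");
#endif
}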