@@ -288,9 +288,14 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
288
288
// TODO(kexinzhao): add a fallback code path for GPUs with compute capability < 53
289
289
PADDLE_ENFORCE_GE (context.GetComputeCapability (), 53 ,
290
290
" cublas Hgemm requires GPU compute capability >= 53" );
291
+
292
+ #if CUDA_VERSION >= 8000
291
293
PADDLE_ENFORCE (platform::dynload::cublasHgemmStridedBatched (
292
294
context.cublas_handle (), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
293
295
strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
296
+ #else
297
+ PADDLE_ENFORCE (false , " HgemmStridedBatched is not supported on cuda <= 7.5" );
298
+ #endif
294
299
}
295
300
296
301
template <>
@@ -310,9 +315,13 @@ void batched_gemm<platform::CUDADeviceContext, float>(
310
315
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
311
316
const int strideC = M * N;
312
317
318
+ #if CUDA_VERSION >= 8000
313
319
PADDLE_ENFORCE (platform::dynload::cublasSgemmStridedBatched (
314
320
context.cublas_handle (), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
315
321
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
322
+ #else
323
+ PADDLE_ENFORCE (false , " SgemmStridedBatched is not supported on cuda <= 7.5" );
324
+ #endif
316
325
}
317
326
318
327
template <>
@@ -332,9 +341,13 @@ void batched_gemm<platform::CUDADeviceContext, double>(
332
341
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
333
342
const int strideC = M * N;
334
343
344
+ #if CUDA_VERSION >= 8000
335
345
PADDLE_ENFORCE (platform::dynload::cublasDgemmStridedBatched (
336
346
context.cublas_handle (), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
337
347
strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount));
348
+ #else
349
+ PADDLE_ENFORCE (false , " DgemmStridedBatched is not supported on cuda <= 7.5" );
350
+ #endif
338
351
}
339
352
340
353
template <>
0 commit comments