@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #define EIGEN_USE_GPU
+#include <vector>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
@@ -267,7 +268,8 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float16 alpha, const float16* A, const float16* B, const float16 beta,
-    float16* C, const int batchCount, const int strideA, const int strideB) {
+    float16* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -278,7 +280,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;

   const half h_alpha = static_cast<const half>(alpha);
   const half h_beta = static_cast<const half>(beta);
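The motivation for the widening, as a hedged aside: each matrix i in a strided batch lives at element offset i * stride, and once that product exceeds INT_MAX a 32-bit computation wraps around. A minimal standalone sketch, not part of the patch and with hypothetical sizes:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical sizes, chosen so the last matrix offset exceeds INT_MAX.
  const int M = 4096, N = 4096, batchCount = 256;

  // Widen to 64-bit before multiplying so the intermediate product cannot
  // overflow 32-bit int.
  const int64_t strideC = static_cast<int64_t>(M) * N;    // 16,777,216
  const int64_t lastOffset = strideC * (batchCount - 1);  // 4,278,190,080

  // 4,278,190,080 > 2,147,483,647 (INT32_MAX), so a 32-bit offset
  // computation would wrap around and index the wrong memory.
  std::cout << "last C offset: " << lastOffset << " elements\n"
            << "INT32_MAX:     " << INT32_MAX << "\n";
  return 0;
}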
@@ -303,7 +305,8 @@ void batched_gemm<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -314,7 +317,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;

   PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
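For context, the underlying cuBLAS routine already takes 64-bit strides: cublasSgemmStridedBatched declares strideA/strideB/strideC as long long int, so the wrapper's int64_t parameters map onto the API directly. Below is a self-contained sketch of the raw call outside Paddle; the tiny sizes are illustrative and error checking is elided, so this is not the PR's code:

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

int main() {
  // Two 2x2 GEMMs in one strided-batched call: C_i = A_i * B_i.
  const int M = 2, N = 2, K = 2, batchCount = 2;
  const int64_t strideA = static_cast<int64_t>(M) * K;
  const int64_t strideB = static_cast<int64_t>(K) * N;
  const int64_t strideC = static_cast<int64_t>(M) * N;

  std::vector<float> hA(strideA * batchCount, 1.0f);  // all ones
  std::vector<float> hB(strideB * batchCount, 2.0f);  // all twos

  float *dA, *dB, *dC;
  cudaMalloc(&dA, hA.size() * sizeof(float));
  cudaMalloc(&dB, hB.size() * sizeof(float));
  cudaMalloc(&dC, strideC * batchCount * sizeof(float));
  cudaMemcpy(dA, hA.data(), hA.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), hB.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);

  const float alpha = 1.0f, beta = 0.0f;
  // cuBLAS is column-major; a row-major caller swaps A/B and M/N, which is
  // the same trick the Paddle wrapper above uses.
  cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
                            dB, N, strideB, dA, K, strideA, &beta, dC, N,
                            strideC, batchCount);

  std::vector<float> hC(strideC * batchCount);
  cudaMemcpy(hC.data(), dC, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);
  // Each entry of C should be K * 1.0f * 2.0f = 4.0f.

  cublasDestroy(handle);
  cudaFree(dA);
  cudaFree(dB);
  cudaFree(dC);
  return 0;
}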
@@ -329,7 +332,8 @@ void batched_gemm<platform::CUDADeviceContext, double>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -340,7 +344,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;

   PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
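One usage note: all three specializations (float16, float, double) receive the identical stride widening and sit behind the same CUDA_VERSION >= 8000 guard, since the cublas*GemmStridedBatched family first shipped with CUDA 8.0; whatever fallback exists for older toolkits lies in the #else branches outside the hunks shown here.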