@@ -224,7 +224,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
224
224
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
225
225
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
226
226
const float alpha, const float * A, const float * B, const float beta,
227
- float * C, const int batchCount, const int strideA, const int strideB) {
227
+ float * C, const int batchCount, const int64_t strideA,
228
+ const int64_t strideB) {
228
229
for (int k = 0 ; k < batchCount; ++k) {
229
230
const float * Ak = &A[k * strideA];
230
231
const float * Bk = &B[k * strideB];
@@ -239,7 +240,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
239
240
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
240
241
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
241
242
const double alpha, const double * A, const double * B, const double beta,
242
- double * C, const int batchCount, const int strideA, const int strideB) {
243
+ double * C, const int batchCount, const int64_t strideA,
244
+ const int64_t strideB) {
243
245
for (int k = 0 ; k < batchCount; ++k) {
244
246
const double * Ak = &A[k * strideA];
245
247
const double * Bk = &B[k * strideB];
0 commit comments