
Commit 2a06e30

Fix batch_gemm bugs
stride should be int64_t, not int
1 parent bfbbe19 commit 2a06e30
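
Why the type matters: the strides are element offsets between consecutive batch slices, and for large per-batch matrices the offset `M * K` computed in 32-bit `int` overflows. Below is a minimal, self-contained sketch of the failure mode; the shapes are hypothetical and not taken from this commit.

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // Hypothetical per-batch shapes: each A slice is M x K elements.
  const int64_t M = 70000, K = 40000;

  // Element offset between consecutive A slices in a strided batch.
  const int64_t strideA = M * K;  // 2,800,000,000

  // A 32-bit int cannot hold this offset, so computing or storing it as
  // `int` silently overflows; int64_t is required.
  std::printf("strideA   = %lld\n", static_cast<long long>(strideA));
  std::printf("INT32_MAX = %d\n", std::numeric_limits<int32_t>::max());
  std::printf("fits in int32? %s\n",
              strideA <= std::numeric_limits<int32_t>::max() ? "yes" : "no");
  return 0;
}
```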

3 files changed: +21 / -12 lines changed


paddle/fluid/operators/math/math_function.cc

Lines changed: 7 additions & 3 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/math_function.h"
+#include <vector>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
 #include "paddle/fluid/platform/float16.h"
@@ -161,7 +162,8 @@ void batched_gemm<platform::CPUDeviceContext, float16>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float16 alpha, const float16* A, const float16* B, const float16 beta,
-    float16* C, const int batchCount, const int strideA, const int strideB) {
+    float16* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   PADDLE_THROW("float16 batched_gemm not supported on CPU");
 }
 
@@ -172,7 +174,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
@@ -194,7 +197,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
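
The CPU specializations above dispatch one cblas_*gemm call per batch slice; the strides only move pointers between slices, so they never need to fit in the 32-bit BLAS dimension arguments. A hedged sketch of that dispatch pattern, as an illustration of the technique rather than the Paddle source:

```cpp
#include <cblas.h>
#include <cstdint>

// Sketch of a CPU strided-batched SGEMM: one cblas_sgemm per slice.  The
// int64_t strides are used only for pointer arithmetic, so large slice
// offsets cannot overflow a 32-bit int.
void sgemm_strided_batched_sketch(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                                  int M, int N, int K, float alpha,
                                  const float* A, const float* B, float beta,
                                  float* C, int batchCount,
                                  int64_t strideA, int64_t strideB) {
  const int lda = (transA == CblasNoTrans) ? K : M;
  const int ldb = (transB == CblasNoTrans) ? N : K;
  const int ldc = N;
  const int64_t strideC = static_cast<int64_t>(M) * N;

  for (int k = 0; k < batchCount; ++k) {
    cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha,
                A + k * strideA, lda, B + k * strideB, ldb, beta,
                C + k * strideC, ldc);
  }
}
```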

paddle/fluid/operators/math/math_function.cu

Lines changed: 10 additions & 6 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #define EIGEN_USE_GPU
+#include <vector>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
@@ -267,7 +268,8 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float16 alpha, const float16* A, const float16* B, const float16 beta,
-    float16* C, const int batchCount, const int strideA, const int strideB) {
+    float16* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -278,7 +280,7 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
 
   const half h_alpha = static_cast<const half>(alpha);
   const half h_beta = static_cast<const half>(beta);
@@ -303,7 +305,8 @@ void batched_gemm<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -314,7 +317,7 @@ void batched_gemm<platform::CUDADeviceContext, float>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
 
   PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
@@ -329,7 +332,8 @@ void batched_gemm<platform::CUDADeviceContext, double>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
 #if CUDA_VERSION >= 8000
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
@@ -340,7 +344,7 @@ void batched_gemm<platform::CUDADeviceContext, double>(
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
       (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
-  const int strideC = M * N;
+  const int64_t strideC = M * N;
 
   PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched(
       context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb,
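
On the GPU path the strides are forwarded to cuBLAS. Since CUDA 8.0, cublasSgemmStridedBatched and cublasDgemmStridedBatched take `long long int` strides, so the new int64_t parameters pass through without narrowing. A minimal sketch of the underlying call follows; it is not the Paddle wrapper, and handle setup, error handling, and the row-/column-major operand swap noted in the diff comments are left to the caller.

```cpp
#include <cublas_v2.h>
#include <cstdint>

// Sketch: single-precision strided-batched GEMM via cuBLAS.  The int64_t
// strides convert losslessly to the long long stride parameters of the API.
cublasStatus_t sgemm_batched_sketch(cublasHandle_t handle,
                                    cublasOperation_t transA, cublasOperation_t transB,
                                    int M, int N, int K, float alpha,
                                    const float* A, int lda, int64_t strideA,
                                    const float* B, int ldb, int64_t strideB,
                                    float beta, float* C, int ldc, int64_t strideC,
                                    int batchCount) {
  return cublasSgemmStridedBatched(handle, transA, transB, M, N, K, &alpha,
                                   A, lda, strideA, B, ldb, strideB, &beta,
                                   C, ldc, strideC, batchCount);
}
```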

paddle/fluid/operators/math/math_function.h

Lines changed: 4 additions & 3 deletions
@@ -26,7 +26,7 @@ limitations under the License. */
 
 #ifndef LAPACK_FOUND
 extern "C" {
-#include <cblas.h>
+#include <cblas.h>  // NOLINT
 int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
                    int* ipiv);
 int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
@@ -39,6 +39,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
 #endif
 
 #include <cmath>
+#include <vector>
 
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -78,8 +79,8 @@ template <typename DeviceContext, typename T>
 void batched_gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA,
                   const CBLAS_TRANSPOSE transB, const int M, const int N,
                   const int K, const T alpha, const T* A, const T* B,
-                  const T beta, T* C, const int batchCount, const int strideA,
-                  const int strideB);
+                  const T beta, T* C, const int batchCount,
+                  const int64_t strideA, const int64_t strideB);
 
 template <typename DeviceContext, typename T>
 void gemv(const DeviceContext& context, const bool trans_a, const int M,
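
Callers of the declaration above should likewise compute the strides in 64-bit arithmetic so the multiply itself cannot overflow before the values reach batched_gemm. An illustrative helper, hypothetical and not part of this header:

```cpp
#include <cstdint>

// Hypothetical caller-side helper: compute the element offsets between
// consecutive batch slices in 64-bit, then pass them as the last two
// arguments of batched_gemm.
inline void make_batch_strides(int M, int N, int K,
                               int64_t* strideA, int64_t* strideB) {
  *strideA = static_cast<int64_t>(M) * K;  // each A slice is M x K
  *strideB = static_cast<int64_t>(K) * N;  // each B slice is K x N
}
```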
