@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
+#include <vector>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
 #include "paddle/fluid/platform/float16.h"
@@ -161,7 +162,8 @@ void batched_gemm<platform::CPUDeviceContext, float16>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float16 alpha, const float16* A, const float16* B, const float16 beta,
-    float16* C, const int batchCount, const int strideA, const int strideB) {
+    float16* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   PADDLE_THROW("float16 batched_gemm not supported on CPU");
 }
@@ -172,7 +174,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
@@ -194,7 +197,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
@@ -220,7 +224,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const float alpha, const float* A, const float* B, const float beta,
-    float* C, const int batchCount, const int strideA, const int strideB) {
+    float* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   for (int k = 0; k < batchCount; ++k) {
     const float* Ak = &A[k * strideA];
     const float* Bk = &B[k * strideB];
@@ -235,7 +240,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
     const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
     const double alpha, const double* A, const double* B, const double beta,
-    double* C, const int batchCount, const int strideA, const int strideB) {
+    double* C, const int batchCount, const int64_t strideA,
+    const int64_t strideB) {
   for (int k = 0; k < batchCount; ++k) {
     const double* Ak = &A[k * strideA];
     const double* Bk = &B[k * strideB];
0 commit comments