@@ -73,6 +73,16 @@ struct CBlas<float> {
73
73
platform::dynload::cblas_sgemv (args...);
74
74
}
75
75
76
+ template <typename ... ARGS>
77
+ static float DOT (ARGS... args) {
78
+ return platform::dynload::cblas_sdot (args...);
79
+ }
80
+
81
+ template <typename ... ARGS>
82
+ static void SCAL (ARGS... args) {
83
+ platform::dynload::cblas_sscal (args...);
84
+ }
85
+
76
86
template <typename ... ARGS>
77
87
static void GEMM_BATCH (ARGS... args) {
78
88
platform::dynload::cblas_sgemm_batch (args...);
@@ -87,6 +97,11 @@ struct CBlas<float> {
87
97
static void VMUL (ARGS... args) {
88
98
platform::dynload::vsMul (args...);
89
99
}
100
+
101
+ template <typename ... ARGS>
102
+ static void VEXP (ARGS... args) {
103
+ platform::dynload::vsExp (args...);
104
+ }
90
105
};
91
106
92
107
template <>
@@ -138,6 +153,16 @@ struct CBlas<double> {
138
153
platform::dynload::cblas_dgemv (args...);
139
154
}
140
155
156
+ template <typename ... ARGS>
157
+ static double DOT (ARGS... args) {
158
+ return platform::dynload::cblas_ddot (args...);
159
+ }
160
+
161
+ template <typename ... ARGS>
162
+ static void SCAL (ARGS... args) {
163
+ platform::dynload::cblas_dscal (args...);
164
+ }
165
+
141
166
template <typename ... ARGS>
142
167
static void GEMM_BATCH (ARGS... args) {
143
168
platform::dynload::cblas_dgemm_batch (args...);
@@ -152,6 +177,11 @@ struct CBlas<double> {
152
177
static void VMUL (ARGS... args) {
153
178
platform::dynload::vdMul (args...);
154
179
}
180
+
181
+ template <typename ... ARGS>
182
+ static void VEXP (ARGS... args) {
183
+ platform::dynload::vdExp (args...);
184
+ }
155
185
};
156
186
157
187
#else
@@ -210,71 +240,16 @@ struct CBlas<platform::float16> {
210
240
PADDLE_THROW (" float16 SMM_GEMM not supported on CPU" );
211
241
}
212
242
// Placeholder for float16 VMUL on CPU: accepts any arguments and
// immediately invokes PADDLE_THROW with the message below.
static void VMUL(...) {
  PADDLE_THROW(" float16 VMUL not supported on CPU");
}
243
+ static void VEXP (...) { PADDLE_THROW (" float16 VEXP not supported on CPU" ); }
244
+ static void DOT (...) { PADDLE_THROW (" float16 DOT not supported on CPU" ); };
245
+ static void SCAL (...) { PADDLE_THROW (" float16 SCAL not supported on CPU" ); };
213
246
#ifdef PADDLE_WITH_MKLML
214
247
static void GEMM_BATCH (...) {
215
248
PADDLE_THROW (" float16 GEMM_BATCH not supported on CPU" );
216
249
}
217
250
#endif
218
251
};
219
252
220
// Decides whether a GEMM call qualifies for the libxsmm small-matrix path.
//
// When PADDLE_WITH_LIBXSMM is defined, returns true only if all of these
// hold:
//   * m * n * k does not exceed the custom threshold (20 * 20 * 20),
//   * neither operand is transposed,
//   * alpha is within epsilon of 1 and beta is within epsilon of 0.
// In every other case, and in builds without PADDLE_WITH_LIBXSMM, returns
// false.
template <typename T>
inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa,
                    bool transb, const T &alpha, const T &beta) {
#ifdef PADDLE_WITH_LIBXSMM
  // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
  // But the threshold is custom
  constexpr long long LIBXSMM_THRESHOLD = 20 * 20 * 20;

  // Widen before multiplying: in plain int the product overflows (undefined
  // behaviour) for ordinary sizes such as 2048 x 2048 x 2048, and a wrapped
  // value could wrongly select the small-matrix kernel.
  const long long volume = static_cast<long long>(m) * n * k;
  if (volume > LIBXSMM_THRESHOLD || transa || transb) {
    return false;
  }

  const T eps = std::numeric_limits<T>::epsilon();
  const bool alpha_is_one = !(std::abs(alpha - static_cast<T>(1)) > eps);
  const bool beta_is_zero = !(std::abs(beta) > eps);
  return alpha_is_one && beta_is_zero;
#else
  return false;
#endif
}
238
-
239
- template <>
240
- inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
241
- bool transa, bool transb,
242
- const platform::float16 &alpha,
243
- const platform::float16 &beta) {
244
- return false ;
245
- }
246
-
247
- template <typename T>
248
- inline void GEMM_WARP (CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
249
- CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
250
- const T *A, int lda, const T *B, int ldb, T beta, T *C,
251
- int ldc) {
252
- #ifdef PADDLE_WITH_LIBXSMM
253
- if (UseXSMM<T>(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
254
- beta)) {
255
- // Note: SMM use ColMajor
256
- const char transa = ' N' ;
257
- const char transb = ' N' ;
258
- CBlas<T>::SMM_GEMM (&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
259
- &beta, C, &ldc);
260
- return ;
261
- }
262
- #endif
263
-
264
- #ifdef PADDLE_MKL_SPLIT_GEMM
265
- constexpr int bs = 2 ;
266
- if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
267
- for (int off = 0 ; off < M; off += bs) {
268
- CBlas<T>::GEMM (CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
269
- A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc);
270
- }
271
- return ;
272
- }
273
- #endif
274
- CBlas<T>::GEMM (CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
275
- beta, C, ldc);
276
- }
277
-
278
253
#ifdef PADDLE_WITH_MKLML
279
254
template <>
280
255
template <typename T>
@@ -319,8 +294,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
319
294
int lda = (transA == CblasNoTrans) ? K : M;
320
295
int ldb = (transB == CblasNoTrans) ? N : K;
321
296
int ldc = N;
322
- GEMM_WARP <T>(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
323
- beta, C, ldc);
297
+ CBlas <T>:: GEMM (CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
298
+ beta, C, ldc);
324
299
}
325
300
326
301
template <>
@@ -329,9 +304,20 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
329
304
int N, int K, T alpha, const T *A,
330
305
int lda, const T *B, int ldb,
331
306
T beta, T *C, int ldc) const {
332
- GEMM_WARP<T>(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
333
- transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
334
- lda, B, ldb, beta, C, ldc);
307
+ CBlas<T>::GEMM (CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
308
+ transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
309
+ lda, B, ldb, beta, C, ldc);
310
+ }
311
+
312
+ template <>
313
+ template <typename T>
314
+ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
315
+ CBLAS_TRANSPOSE transB, int M,
316
+ int N, int K, T alpha, const T *A,
317
+ int lda, const T *B, int ldb,
318
+ T beta, T *C, int ldc) const {
319
+ CBlas<T>::GEMM (CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
320
+ beta, C, ldc);
335
321
}
336
322
337
323
template <typename DeviceContext>
@@ -399,6 +385,47 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
399
385
#endif
400
386
}
401
387
388
+ template <>
389
+ template <typename T>
390
+ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
391
+ #ifdef PADDLE_WITH_MKLML
392
+ CBlas<T>::VEXP (n, x, y);
393
+ #else
394
+ // try to find if openblas support vexp
395
+ for (int i = 0 ; i < n; ++i) {
396
+ y[i] = std::exp (x[i]);
397
+ }
398
+ #endif
399
+ }
400
+
401
+ template <>
402
+ template <typename T>
403
+ T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
404
+ #ifdef PADDLE_WITH_MKLML
405
+ return CBlas<T>::DOT (n, x, 1 , y, 1 );
406
+ #else
407
+ // try to find if openblas support cblas_dot
408
+ T sum = 0 ;
409
+ for (int i = 0 ; i < n; ++i) {
410
+ sum += x[i] * y[i];
411
+ }
412
+ return sum;
413
+ #endif
414
+ }
415
+
416
+ template <>
417
+ template <typename T>
418
+ void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
419
+ #ifdef PADDLE_WITH_MKLML
420
+ CBlas<T>::SCAL (n, a, x, 1 );
421
+ #else
422
+ // try to find if openblas support cblas_scal
423
+ for (int i = 0 ; i < n; ++i) {
424
+ x[i] = a * x[i];
425
+ }
426
+ #endif
427
+ }
428
+
402
429
template <>
403
430
template <typename T>
404
431
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
@@ -440,6 +467,42 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
440
467
#endif
441
468
}
442
469
470
+ template <typename DeviceContext>
471
+ template <typename T>
472
+ void Blas<DeviceContext>::MatMul(const int M, const int N, const int K,
473
+ const T *A, const T *B, T *C) const {
474
+ this ->template GEMM <T>(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
475
+ static_cast <T>(1 ), A, K, B, N, static_cast <T>(0 ), C,
476
+ N);
477
+ }
478
+
479
+ template <>
480
+ template <typename T>
481
+ void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
482
+ const int K, const T *A,
483
+ const T *B, T *C) const {
484
+ #ifdef PADDLE_WITH_LIBXSMM
485
+ // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
486
+ // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
487
+
488
+ // Since the matrix is very small,
489
+ // so the unit of calculation is already very fast,
490
+ // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead,
491
+ // use xsmm directly.
492
+ // Note: SMM use ColMajor
493
+ const char transa = ' N' ;
494
+ const char transb = ' N' ;
495
+ const T alpha = static_cast <T>(1 );
496
+ const T beta = static_cast <T>(0 );
497
+ CBlas<T>::SMM_GEMM (&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta,
498
+ C, &N);
499
+ return ;
500
+ #endif
501
+
502
+ CBlas<T>::GEMM (CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
503
+ static_cast <T>(1 ), A, K, B, N, static_cast <T>(0 ), C, N);
504
+ }
505
+
443
506
template <typename DeviceContext>
444
507
template <typename T>
445
508
void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
0 commit comments