@@ -37,6 +37,7 @@ struct CBlas<float> {
37
37
libxsmm_sgemm (args...);
38
38
}
39
39
#endif
40
+
40
41
template <typename ... ARGS>
41
42
static void AXPY (ARGS... args) {
42
43
platform::dynload::cblas_saxpy (args...);
@@ -76,6 +77,7 @@ struct CBlas<double> {
76
77
libxsmm_dgemm (args...);
77
78
}
78
79
#endif
80
+
79
81
template <typename ... ARGS>
80
82
static void AXPY (ARGS... args) {
81
83
platform::dynload::cblas_daxpy (args...);
@@ -150,6 +152,7 @@ struct CBlas<double> {
150
152
}
151
153
};
152
154
#endif
155
+
153
156
template <>
154
157
struct CBlas <platform::float16> {
155
158
static void GEMM (...) { PADDLE_THROW (" float16 GEMM not supported on CPU" ); }
@@ -190,30 +193,48 @@ inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
190
193
return false ;
191
194
}
192
195
193
- template <>
194
196
template <typename T>
195
- void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
196
- CBLAS_TRANSPOSE transB, int M,
197
- int N, int K, T alpha, const T *A,
198
- const T *B, T beta, T *C) const {
199
- int lda = (transA == CblasNoTrans) ? K : M;
200
- int ldb = (transB == CblasNoTrans) ? N : K;
201
- int ldc = N;
197
+ inline void GEMM_WARP (CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
198
+ CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
199
+ const T *A, int lda, const T *B, int ldb, T beta, T *C,
200
+ int ldc) {
202
201
#ifdef PADDLE_WITH_LIBXSMM
203
- if (UseXSMM (M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
204
- beta)) {
202
+ if (UseXSMM<T> (M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
203
+ beta)) {
205
204
// Note: SMM use ColMajor
206
205
const char transa = ' N' ;
207
206
const char transb = ' N' ;
208
207
CBlas<T>::SMM_GEMM (&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
209
208
&beta, C, &ldc);
210
- } else {
209
+ return ;
210
+ }
211
211
#endif
212
- CBlas<T>::GEMM (CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B,
213
- ldb, beta, C, ldc);
214
- #ifdef PADDLE_WITH_LIBXSMM
212
+
213
+ #ifdef PADDLE_MKL_SPLIT_GEMM
214
+ constexpr int bs = 2 ;
215
+ if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
216
+ for (int off = 0 ; off < M; off += bs) {
217
+ CBlas<T>::GEMM (CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
218
+ A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc);
219
+ }
220
+ return ;
215
221
}
216
222
#endif
223
+ CBlas<T>::GEMM (CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
224
+ beta, C, ldc);
225
+ }
226
+
227
+ template <>
228
+ template <typename T>
229
+ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
230
+ CBLAS_TRANSPOSE transB, int M,
231
+ int N, int K, T alpha, const T *A,
232
+ const T *B, T beta, T *C) const {
233
+ int lda = (transA == CblasNoTrans) ? K : M;
234
+ int ldb = (transB == CblasNoTrans) ? N : K;
235
+ int ldc = N;
236
+ GEMM_WARP<T>(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
237
+ beta, C, ldc);
217
238
}
218
239
219
240
template <>
@@ -222,9 +243,9 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
222
243
int N, int K, T alpha, const T *A,
223
244
int lda, const T *B, int ldb,
224
245
T beta, T *C, int ldc) const {
225
- CBlas <T>:: GEMM (CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
226
- transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
227
- lda, B, ldb, beta, C, ldc);
246
+ GEMM_WARP <T>(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
247
+ transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
248
+ lda, B, ldb, beta, C, ldc);
228
249
}
229
250
230
251
template <typename DeviceContext>
0 commit comments