Skip to content

Commit a916c52

Browse files
committed
refine gemm
1 parent 961e754 commit a916c52

File tree

1 file changed

+35
-29
lines changed

1 file changed

+35
-29
lines changed

paddle/fluid/operators/math/blas_impl.h

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct CBlas<float> {
3737
libxsmm_sgemm(args...);
3838
}
3939
#endif
40+
4041
template <typename... ARGS>
4142
static void AXPY(ARGS... args) {
4243
platform::dynload::cblas_saxpy(args...);
@@ -76,6 +77,7 @@ struct CBlas<double> {
7677
libxsmm_dgemm(args...);
7778
}
7879
#endif
80+
7981
template <typename... ARGS>
8082
static void AXPY(ARGS... args) {
8183
platform::dynload::cblas_daxpy(args...);
@@ -150,6 +152,7 @@ struct CBlas<double> {
150152
}
151153
};
152154
#endif
155+
153156
template <>
154157
struct CBlas<platform::float16> {
155158
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
@@ -190,45 +193,48 @@ inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
190193
return false;
191194
}
192195

193-
template <>
194196
template <typename T>
195-
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
196-
CBLAS_TRANSPOSE transB, int M,
197-
int N, int K, T alpha, const T *A,
198-
const T *B, T beta, T *C) const {
199-
int lda = (transA == CblasNoTrans) ? K : M;
200-
int ldb = (transB == CblasNoTrans) ? N : K;
201-
int ldc = N;
197+
inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
198+
CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
199+
const T *A, int lda, const T *B, int ldb, T beta, T *C,
200+
int ldc) {
202201
#ifdef PADDLE_WITH_LIBXSMM
203-
if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
204-
beta)) {
202+
if (UseXSMM<T>(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
203+
beta)) {
205204
// Note: SMM use ColMajor
206205
const char transa = 'N';
207206
const char transb = 'N';
208207
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
209208
&beta, C, &ldc);
210-
} else {
209+
return;
210+
}
211211
#endif
212212

213213
#ifdef PADDLE_MKL_SPLIT_GEMM
214-
constexpr int bs = 2;
215-
if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
216-
for (int off = 0; off < M; off += bs) {
217-
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, off, N, K,
218-
alpha, A + off * lda, lda, B, ldb, beta, C + off * ldb,
219-
ldc);
220-
}
221-
} else {
222-
#endif
223-
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B,
224-
ldb, beta, C, ldc);
225-
#ifdef PADDLE_MKL_SPLIT_GEMM
214+
constexpr int bs = 2;
215+
if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
216+
for (int off = 0; off < M; off += bs) {
217+
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
218+
A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc);
226219
}
227-
#endif
228-
229-
#ifdef PADDLE_WITH_LIBXSMM
220+
return;
230221
}
231222
#endif
223+
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
224+
beta, C, ldc);
225+
}
226+
227+
template <>
228+
template <typename T>
229+
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
230+
CBLAS_TRANSPOSE transB, int M,
231+
int N, int K, T alpha, const T *A,
232+
const T *B, T beta, T *C) const {
233+
int lda = (transA == CblasNoTrans) ? K : M;
234+
int ldb = (transB == CblasNoTrans) ? N : K;
235+
int ldc = N;
236+
GEMM_WARP<T>(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
237+
beta, C, ldc);
232238
}
233239

234240
template <>
@@ -237,9 +243,9 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
237243
int N, int K, T alpha, const T *A,
238244
int lda, const T *B, int ldb,
239245
T beta, T *C, int ldc) const {
240-
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
241-
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
242-
lda, B, ldb, beta, C, ldc);
246+
GEMM_WARP<T>(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
247+
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
248+
lda, B, ldb, beta, C, ldc);
243249
}
244250

245251
template <typename DeviceContext>

0 commit comments

Comments
 (0)