Skip to content

Commit be04fbf

Browse files
authored
Merge pull request #12233 from tensor-tang/refine/mkl/gemm
add option split mkl gemm
2 parents 7219676 + fc2b578 commit be04fbf

File tree

3 files changed

+98
-17
lines changed

3 files changed

+98
-17
lines changed

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ else()
136136
set(THIRD_PARTY_BUILD_TYPE Release)
137137
endif()
138138

139+
if(WITH_MKL)
140+
option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF)
141+
if (MKL_SPLIT_GEMM)
142+
add_definitions(-DPADDLE_MKL_SPLIT_GEMM)
143+
endif()
144+
endif()
139145
set(WITH_MKLML ${WITH_MKL})
140146
if (NOT DEFINED WITH_MKLDNN)
141147
if (WITH_MKL AND AVX2_FOUND)

paddle/fluid/operators/math/blas_impl.h

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct CBlas<float> {
3737
libxsmm_sgemm(args...);
3838
}
3939
#endif
40+
4041
template <typename... ARGS>
4142
static void AXPY(ARGS... args) {
4243
platform::dynload::cblas_saxpy(args...);
@@ -76,6 +77,7 @@ struct CBlas<double> {
7677
libxsmm_dgemm(args...);
7778
}
7879
#endif
80+
7981
template <typename... ARGS>
8082
static void AXPY(ARGS... args) {
8183
platform::dynload::cblas_daxpy(args...);
@@ -150,6 +152,7 @@ struct CBlas<double> {
150152
}
151153
};
152154
#endif
155+
153156
template <>
154157
struct CBlas<platform::float16> {
155158
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
@@ -190,30 +193,48 @@ inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
190193
return false;
191194
}
192195

193-
template <>
194196
template <typename T>
195-
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
196-
CBLAS_TRANSPOSE transB, int M,
197-
int N, int K, T alpha, const T *A,
198-
const T *B, T beta, T *C) const {
199-
int lda = (transA == CblasNoTrans) ? K : M;
200-
int ldb = (transB == CblasNoTrans) ? N : K;
201-
int ldc = N;
197+
inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
198+
CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
199+
const T *A, int lda, const T *B, int ldb, T beta, T *C,
200+
int ldc) {
202201
#ifdef PADDLE_WITH_LIBXSMM
203-
if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
204-
beta)) {
202+
if (UseXSMM<T>(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
203+
beta)) {
205204
// Note: SMM use ColMajor
206205
const char transa = 'N';
207206
const char transb = 'N';
208207
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
209208
&beta, C, &ldc);
210-
} else {
209+
return;
210+
}
211211
#endif
212-
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B,
213-
ldb, beta, C, ldc);
214-
#ifdef PADDLE_WITH_LIBXSMM
212+
213+
#ifdef PADDLE_MKL_SPLIT_GEMM
214+
constexpr int bs = 2;
215+
if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
216+
for (int off = 0; off < M; off += bs) {
217+
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
218+
A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc);
219+
}
220+
return;
215221
}
216222
#endif
223+
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
224+
beta, C, ldc);
225+
}
226+
227+
template <>
228+
template <typename T>
229+
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
230+
CBLAS_TRANSPOSE transB, int M,
231+
int N, int K, T alpha, const T *A,
232+
const T *B, T beta, T *C) const {
233+
int lda = (transA == CblasNoTrans) ? K : M;
234+
int ldb = (transB == CblasNoTrans) ? N : K;
235+
int ldc = N;
236+
GEMM_WARP<T>(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
237+
beta, C, ldc);
217238
}
218239

219240
template <>
@@ -222,9 +243,9 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
222243
int N, int K, T alpha, const T *A,
223244
int lda, const T *B, int ldb,
224245
T beta, T *C, int ldc) const {
225-
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
226-
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
227-
lda, B, ldb, beta, C, ldc);
246+
GEMM_WARP<T>(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
247+
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
248+
lda, B, ldb, beta, C, ldc);
228249
}
229250

230251
template <typename DeviceContext>

paddle/fluid/operators/math/math_function_test.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,57 @@ TEST(math_funciton, set_constant) {
228228
}
229229
delete ctx;
230230
}
231+
232+
template <typename T>
233+
void GemmWarpTest(int m, int n, int k, T alpha, T beta) {
234+
paddle::framework::Tensor mat_a;
235+
paddle::framework::Tensor mat_b;
236+
paddle::framework::Tensor mat_c_ref;
237+
paddle::framework::Tensor mat_c_mkl;
238+
auto* cpu_place = new paddle::platform::CPUPlace();
239+
240+
T* A = mat_a.mutable_data<T>({m, k}, *cpu_place);
241+
T* B = mat_b.mutable_data<T>({k, n}, *cpu_place);
242+
T* CREF = mat_c_ref.mutable_data<T>({m, n}, *cpu_place);
243+
T* CMKL = mat_c_mkl.mutable_data<T>({m, n}, *cpu_place);
244+
245+
ASSERT_EQ(mat_c_mkl.numel(), mat_c_ref.numel());
246+
for (int i = 0; i < mat_a.numel(); ++i) {
247+
A[i] = static_cast<T>(i);
248+
}
249+
for (int i = 0; i < mat_b.numel(); ++i) {
250+
B[i] = static_cast<T>(i + 1);
251+
}
252+
for (int i = 0; i < mat_c_ref.numel(); ++i) {
253+
CREF[i] = static_cast<T>(i + 2);
254+
CMKL[i] = CREF[i];
255+
}
256+
257+
// this would call gemm_warp
258+
paddle::platform::CPUDeviceContext context(*cpu_place);
259+
GetBlas<T>(context).GEMM(CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, B,
260+
beta, CREF);
261+
262+
// lda,ldb,ldc follow RowMajor
263+
int lda = k;
264+
int ldb = n;
265+
int ldc = n;
266+
paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans,
267+
CblasNoTrans, m, n, k, alpha, A, lda,
268+
B, ldb, beta, CMKL, ldc);
269+
270+
for (int i = 0; i < mat_c_mkl.numel(); ++i) {
271+
EXPECT_FLOAT_EQ(CREF[i], CMKL[i]);
272+
}
273+
}
274+
275+
TEST(math_function, gemm_warp) {
276+
GemmWarpTest<float>(3, 2, 5, 1.f, 0.f);
277+
GemmWarpTest<float>(3, 2, 5, 2.f, 1.f);
278+
GemmWarpTest<float>(8, 5, 6, 1.f, 0.f);
279+
GemmWarpTest<float>(8, 5, 6, 2.f, 1.f);
280+
GemmWarpTest<double>(3, 2, 5, 1.0, 0.0);
281+
GemmWarpTest<double>(3, 2, 5, 2.0, 1.0);
282+
GemmWarpTest<double>(8, 5, 6, 1.0, 0.0);
283+
GemmWarpTest<double>(8, 5, 6, 2.0, 1.0);
284+
}

0 commit comments

Comments
 (0)