Skip to content

Commit 43cee33

Browse files
committed
add mkl packed gemm
1 parent 0964de1 commit 43cee33

File tree

3 files changed

+118
-0
lines changed

3 files changed

+118
-0
lines changed

paddle/fluid/operators/math/blas.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,23 @@ class Blas {
9090
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
9191
int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
9292

93+
template <typename T>
94+
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
95+
const int K) const;
96+
97+
template <typename T>
98+
void GEMM_PACK(const CBLAS_IDENTIFIER id, const CBLAS_TRANSPOSE trans, int M,
99+
int N, int K, const T alpha, const T* src, const int ld,
100+
T* dst) const;
101+
102+
template <typename T>
103+
void GEMM_COMPUTE(int transA, int transB, int M, int N, int K, const T* A,
104+
const int lda, const T* B, const int ldb, T beta, T* C,
105+
const int ldc) const;
106+
107+
template <typename T>
108+
void GEMM_FREE(T* data) const;
109+
93110
template <typename T>
94111
void MatMul(const framework::Tensor& mat_a, bool trans_a,
95112
const framework::Tensor& mat_b, bool trans_b, T alpha,
@@ -146,6 +163,26 @@ class BlasT : private Blas<DeviceContext> {
146163
Base()->template GEMM<T>(args...);
147164
}
148165

166+
template <typename... ARGS>
167+
T* GEMM_ALLOC(ARGS... args) const {
168+
Base()->template GEMM_ALLOC<T>(args...);
169+
}
170+
171+
template <typename... ARGS>
172+
void GEMM_PACK(ARGS... args) const {
173+
Base()->template GEMM_PACK<T>(args...);
174+
}
175+
176+
template <typename... ARGS>
177+
void GEMM_COMPUTE(ARGS... args) const {
178+
Base()->template GEMM_COMPUTE<T>(args...);
179+
}
180+
181+
template <typename... ARGS>
182+
void GEMM_FREE(ARGS... args) const {
183+
Base()->template GEMM_FREE<T>(args...);
184+
}
185+
149186
template <typename... ARGS>
150187
void MatMul(ARGS... args) const {
151188
Base()->template MatMul<T>(args...);

paddle/fluid/operators/math/blas_impl.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,26 @@ struct CBlas<float> {
3131
platform::dynload::cblas_sgemm(args...);
3232
}
3333

34+
template <typename... ARGS>
35+
static float *GEMM_ALLOC(ARGS... args) {
36+
return platform::dynload::cblas_sgemm_alloc(args...);
37+
}
38+
39+
template <typename... ARGS>
40+
static void GEMM_PACK(ARGS... args) {
41+
platform::dynload::cblas_sgemm_pack(args...);
42+
}
43+
44+
template <typename... ARGS>
45+
static void GEMM_COMPUTE(ARGS... args) {
46+
platform::dynload::cblas_sgemm_compute(args...);
47+
}
48+
49+
template <typename... ARGS>
50+
static void GEMM_FREE(ARGS... args) {
51+
platform::dynload::cblas_sgemm_free(args...);
52+
}
53+
3454
#ifdef PADDLE_WITH_LIBXSMM
3555
template <typename... ARGS>
3656
static void SMM_GEMM(ARGS... args) {
@@ -71,6 +91,26 @@ struct CBlas<double> {
7191
platform::dynload::cblas_dgemm(args...);
7292
}
7393

94+
template <typename... ARGS>
95+
static double *GEMM_ALLOC(ARGS... args) {
96+
return platform::dynload::cblas_dgemm_alloc(args...);
97+
}
98+
99+
template <typename... ARGS>
100+
static void GEMM_PACK(ARGS... args) {
101+
platform::dynload::cblas_dgemm_pack(args...);
102+
}
103+
104+
template <typename... ARGS>
105+
static void GEMM_COMPUTE(ARGS... args) {
106+
platform::dynload::cblas_dgemm_compute(args...);
107+
}
108+
109+
template <typename... ARGS>
110+
static void GEMM_FREE(ARGS... args) {
111+
platform::dynload::cblas_dgemm_free(args...);
112+
}
113+
74114
#ifdef PADDLE_WITH_LIBXSMM
75115
template <typename... ARGS>
76116
static void SMM_GEMM(ARGS... args) {
@@ -224,6 +264,39 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
224264
beta, C, ldc);
225265
}
226266

267+
template <>
268+
template <typename T>
269+
T *Blas<platform::CPUDeviceContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id,
270+
const int M, const int N,
271+
const int K) const {
272+
return CBlas<T>::GEMM_ALLOC(id, M, N, K);
273+
}
274+
275+
template <>
276+
template <typename T>
277+
void Blas<platform::CPUDeviceContext>::GEMM_PACK(const CBLAS_IDENTIFIER id,
278+
const CBLAS_TRANSPOSE trans,
279+
int M, int N, int K,
280+
const T alpha, const T *src,
281+
const int ld, T *dst) const {
282+
CBlas<T>::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst);
283+
}
284+
285+
template <>
286+
template <typename T>
287+
void Blas<platform::CPUDeviceContext>::GEMM_COMPUTE(
288+
int transA, int transB, int M, int N, int K, const T *A, const int lda,
289+
const T *B, const int ldb, T beta, T *C, const int ldc) const {
290+
CBlas<T>::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb,
291+
beta, C, ldc);
292+
}
293+
294+
template <>
295+
template <typename T>
296+
void Blas<platform::CPUDeviceContext>::GEMM_FREE(T *data) const {
297+
CBlas<T>::GEMM_FREE(data);
298+
}
299+
227300
template <>
228301
template <typename T>
229302
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,

paddle/fluid/platform/dynload/mklml.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ extern void* mklml_dso_handle;
6060
__macro(cblas_dgemm_batch); \
6161
__macro(vsAdd); \
6262
__macro(vdAdd); \
63+
__macro(cblas_sgemm_alloc); \
64+
__macro(cblas_sgemm_pack); \
65+
__macro(cblas_sgemm_compute); \
66+
__macro(cblas_sgemm_free); \
67+
__macro(cblas_dgemm_alloc); \
68+
__macro(cblas_dgemm_pack); \
69+
__macro(cblas_dgemm_compute); \
70+
__macro(cblas_dgemm_free); \
6371
__macro(MKL_Set_Num_Threads)
6472

6573
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);

0 commit comments

Comments
 (0)