@@ -31,6 +31,26 @@ struct CBlas<float> {
31
31
platform::dynload::cblas_sgemm (args...);
32
32
}
33
33
34
+ template <typename ... ARGS>
35
+ static float *GEMM_ALLOC (ARGS... args) {
36
+ return platform::dynload::cblas_sgemm_alloc (args...);
37
+ }
38
+
39
+ template <typename ... ARGS>
40
+ static void GEMM_PACK (ARGS... args) {
41
+ platform::dynload::cblas_sgemm_pack (args...);
42
+ }
43
+
44
+ template <typename ... ARGS>
45
+ static void GEMM_COMPUTE (ARGS... args) {
46
+ platform::dynload::cblas_sgemm_compute (args...);
47
+ }
48
+
49
+ template <typename ... ARGS>
50
+ static void GEMM_FREE (ARGS... args) {
51
+ platform::dynload::cblas_sgemm_free (args...);
52
+ }
53
+
34
54
#ifdef PADDLE_WITH_LIBXSMM
35
55
template <typename ... ARGS>
36
56
static void SMM_GEMM (ARGS... args) {
@@ -71,6 +91,26 @@ struct CBlas<double> {
71
91
platform::dynload::cblas_dgemm (args...);
72
92
}
73
93
94
+ template <typename ... ARGS>
95
+ static double *GEMM_ALLOC (ARGS... args) {
96
+ return platform::dynload::cblas_dgemm_alloc (args...);
97
+ }
98
+
99
+ template <typename ... ARGS>
100
+ static void GEMM_PACK (ARGS... args) {
101
+ platform::dynload::cblas_dgemm_pack (args...);
102
+ }
103
+
104
+ template <typename ... ARGS>
105
+ static void GEMM_COMPUTE (ARGS... args) {
106
+ platform::dynload::cblas_dgemm_compute (args...);
107
+ }
108
+
109
+ template <typename ... ARGS>
110
+ static void GEMM_FREE (ARGS... args) {
111
+ platform::dynload::cblas_dgemm_free (args...);
112
+ }
113
+
74
114
#ifdef PADDLE_WITH_LIBXSMM
75
115
template <typename ... ARGS>
76
116
static void SMM_GEMM (ARGS... args) {
@@ -224,6 +264,39 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
224
264
beta, C, ldc);
225
265
}
226
266
267
+ template <>
268
+ template <typename T>
269
+ T *Blas<platform::CPUDeviceContext>::GEMM_ALLOC(const CBLAS_IDENTIFIER id,
270
+ const int M, const int N,
271
+ const int K) const {
272
+ return CBlas<T>::GEMM_ALLOC (id, M, N, K);
273
+ }
274
+
275
+ template <>
276
+ template <typename T>
277
+ void Blas<platform::CPUDeviceContext>::GEMM_PACK(const CBLAS_IDENTIFIER id,
278
+ const CBLAS_TRANSPOSE trans,
279
+ int M, int N, int K,
280
+ const T alpha, const T *src,
281
+ const int ld, T *dst) const {
282
+ CBlas<T>::GEMM_PACK (CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst);
283
+ }
284
+
285
+ template <>
286
+ template <typename T>
287
+ void Blas<platform::CPUDeviceContext>::GEMM_COMPUTE(
288
+ int transA, int transB, int M, int N, int K, const T *A, const int lda,
289
+ const T *B, const int ldb, T beta, T *C, const int ldc) const {
290
+ CBlas<T>::GEMM_COMPUTE (CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb,
291
+ beta, C, ldc);
292
+ }
293
+
294
+ template <>
295
+ template <typename T>
296
+ void Blas<platform::CPUDeviceContext>::GEMM_FREE(T *data) const {
297
+ CBlas<T>::GEMM_FREE (data);
298
+ }
299
+
227
300
template <>
228
301
template <typename T>
229
302
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
0 commit comments