Skip to content

Commit 6644ce7

Browse files
committed
add mklml vmul
1 parent ff92b6b commit 6644ce7

File tree

3 files changed

+45
-10
lines changed

3 files changed

+45
-10
lines changed

paddle/fluid/operators/math/blas.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ class Blas {
134134
template <typename T>
135135
void VADD(int n, const T* x, const T* y, T* z) const;
136136

137+
// Element-wise product of two length-n arrays: the CPU specialization
// writes x[i] * y[i] into z[i] for every i in [0, n).
template <typename T>
138+
void VMUL(int n, const T* x, const T* y, T* z) const;
139+
137140
template <typename T>
138141
void VCOPY(int n, const T* x, T* y) const;
139142

@@ -202,6 +205,11 @@ class BlasT : private Blas<DeviceContext> {
202205
Base()->template VADD<T>(args...);
203206
}
204207

208+
// Convenience wrapper: passes every argument through unchanged to the
// VMUL<T> member reached via Base(), so callers do not have to spell out
// the element type T themselves.
template <typename... ARGS>
209+
void VMUL(ARGS... args) const {
210+
Base()->template VMUL<T>(args...);
211+
}
212+
205213
template <typename... ARGS>
206214
void VCOPY(ARGS... args) const {
207215
Base()->template VCOPY<T>(args...);

paddle/fluid/operators/math/blas_impl.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ struct CBlas<float> {
8282
static void VADD(ARGS... args) {
8383
platform::dynload::vsAdd(args...);
8484
}
85+
86+
// float flavour of VMUL: hands all arguments straight to the dynamically
// loaded MKLML routine vsMul.
template <typename... ARGS>
87+
static void VMUL(ARGS... args) {
88+
platform::dynload::vsMul(args...);
89+
}
8590
};
8691

8792
template <>
@@ -142,6 +147,11 @@ struct CBlas<double> {
142147
static void VADD(ARGS... args) {
143148
platform::dynload::vdAdd(args...);
144149
}
150+
151+
// double flavour of VMUL: hands all arguments straight to the dynamically
// loaded MKLML routine vdMul.
template <typename... ARGS>
153+
static void VMUL(ARGS... args) {
154+
platform::dynload::vdMul(args...);
155+
}
145155
};
146156

147157
#else
@@ -199,6 +209,7 @@ struct CBlas<platform::float16> {
199209
static void SMM_GEMM(...) {
200210
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
201211
}
212+
// float16 has no CPU VMUL implementation here; any call throws via PADDLE_THROW.
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
202213
#ifdef PADDLE_WITH_MKLML
203214
static void GEMM_BATCH(...) {
204215
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
@@ -374,6 +385,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
374385
#endif
375386
}
376387

388+
// CPU element-wise multiply: z[i] = x[i] * y[i] for i in [0, n).
// When PADDLE_WITH_MKLML is defined the work is delegated to
// CBlas<T>::VMUL; otherwise a plain scalar loop is used.
template <>
389+
template <typename T>
390+
void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
391+
T *z) const {
392+
#ifdef PADDLE_WITH_MKLML
393+
CBlas<T>::VMUL(n, x, y, z);
394+
#else
395+
// TODO: check whether OpenBLAS offers an equivalent vmul routine; until
// then, fall back to a straightforward element-by-element loop.
396+
for (int i = 0; i < n; ++i) {
397+
z[i] = x[i] * y[i];
398+
}
399+
#endif
400+
}
401+
377402
template <>
378403
template <typename T>
379404
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,

paddle/fluid/platform/dynload/mklml.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,27 @@ extern void* mklml_dso_handle;
4949

5050
#define MKLML_ROUTINE_EACH(__macro) \
5151
__macro(cblas_sgemm); \
52-
__macro(cblas_saxpy); \
53-
__macro(cblas_scopy); \
54-
__macro(cblas_sgemv); \
55-
__macro(cblas_sgemm_batch); \
5652
__macro(cblas_dgemm); \
53+
__macro(cblas_saxpy); \
5754
__macro(cblas_daxpy); \
55+
__macro(cblas_scopy); \
5856
__macro(cblas_dcopy); \
57+
__macro(cblas_sgemv); \
5958
__macro(cblas_dgemv); \
60-
__macro(cblas_dgemm_batch); \
61-
__macro(vsAdd); \
62-
__macro(vdAdd); \
6359
__macro(cblas_sgemm_alloc); \
64-
__macro(cblas_sgemm_pack); \
65-
__macro(cblas_sgemm_compute); \
66-
__macro(cblas_sgemm_free); \
6760
__macro(cblas_dgemm_alloc); \
61+
__macro(cblas_sgemm_pack); \
6862
__macro(cblas_dgemm_pack); \
63+
__macro(cblas_sgemm_compute); \
6964
__macro(cblas_dgemm_compute); \
65+
__macro(cblas_sgemm_free); \
7066
__macro(cblas_dgemm_free); \
67+
__macro(cblas_sgemm_batch); \
68+
__macro(cblas_dgemm_batch); \
69+
__macro(vsAdd); \
70+
__macro(vdAdd); \
71+
__macro(vsMul); \
72+
__macro(vdMul); \
7173
__macro(MKL_Set_Num_Threads)
7274

7375
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);

0 commit comments

Comments
 (0)