Skip to content

Commit d04ef27

Browse files
authored
Merge pull request #12745 from tensor-tang/refine/op/elewise_mul
Refine elementwise mul cpu forward
2 parents cbc6e6e + a56142c commit d04ef27

File tree

4 files changed

+82
-13
lines changed

4 files changed

+82
-13
lines changed

paddle/fluid/operators/elementwise_mul_op.h

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#pragma once
1616
#include "paddle/fluid/operators/elementwise_op_function.h"
17+
#include "paddle/fluid/operators/math/blas.h"
1718

1819
namespace paddle {
1920
namespace operators {
@@ -23,6 +24,37 @@ struct MulFunctor {
2324
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
2425
};
2526

27+
template <typename DeviceContext, typename T>
28+
void default_elementwise_mul(const framework::ExecutionContext& ctx,
29+
const framework::Tensor* x,
30+
const framework::Tensor* y, framework::Tensor* z) {
31+
int axis = ctx.Attr<int>("axis");
32+
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
33+
MulFunctor<T>(), z);
34+
}
35+
36+
template <typename DeviceContext, typename T>
37+
typename std::enable_if<
38+
std::is_floating_point<T>::value &&
39+
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
40+
elementwise_mul(const framework::ExecutionContext& ctx,
41+
const framework::Tensor* x, const framework::Tensor* y,
42+
framework::Tensor* z) {
43+
auto blas = math::GetBlas<DeviceContext, T>(ctx);
44+
blas.VMUL(x->numel(), x->data<T>(), y->data<T>(),
45+
z->mutable_data<T>(ctx.GetPlace()));
46+
}
47+
48+
template <typename DeviceContext, typename T>
49+
typename std::enable_if<
50+
!std::is_floating_point<T>::value ||
51+
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
52+
elementwise_mul(const framework::ExecutionContext& ctx,
53+
const framework::Tensor* x, const framework::Tensor* y,
54+
framework::Tensor* z) {
55+
default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
56+
}
57+
2658
template <typename DeviceContext, typename T>
2759
class ElementwiseMulKernel : public framework::OpKernel<T> {
2860
public:
@@ -33,9 +65,11 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
3365
auto* y = ctx.Input<Tensor>("Y");
3466
auto* z = ctx.Output<Tensor>("Out");
3567
z->mutable_data<T>(ctx.GetPlace());
36-
int axis = ctx.Attr<int>("axis");
37-
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
38-
MulFunctor<T>(), z);
68+
if (x->numel() == y->numel()) {
69+
elementwise_mul<DeviceContext, T>(ctx, x, y, z);
70+
} else {
71+
default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
72+
}
3973
}
4074
};
4175

paddle/fluid/operators/math/blas.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ class Blas {
134134
template <typename T>
135135
void VADD(int n, const T* x, const T* y, T* z) const;
136136

137+
template <typename T>
138+
void VMUL(int n, const T* x, const T* y, T* z) const;
139+
137140
template <typename T>
138141
void VCOPY(int n, const T* x, T* y) const;
139142

@@ -202,6 +205,11 @@ class BlasT : private Blas<DeviceContext> {
202205
Base()->template VADD<T>(args...);
203206
}
204207

208+
template <typename... ARGS>
209+
void VMUL(ARGS... args) const {
210+
Base()->template VMUL<T>(args...);
211+
}
212+
205213
template <typename... ARGS>
206214
void VCOPY(ARGS... args) const {
207215
Base()->template VCOPY<T>(args...);

paddle/fluid/operators/math/blas_impl.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ struct CBlas<float> {
8282
static void VADD(ARGS... args) {
8383
platform::dynload::vsAdd(args...);
8484
}
85+
86+
template <typename... ARGS>
87+
static void VMUL(ARGS... args) {
88+
platform::dynload::vsMul(args...);
89+
}
8590
};
8691

8792
template <>
@@ -142,6 +147,11 @@ struct CBlas<double> {
142147
static void VADD(ARGS... args) {
143148
platform::dynload::vdAdd(args...);
144149
}
150+
151+
template <typename... ARGS>
152+
static void VMUL(ARGS... args) {
153+
platform::dynload::vdMul(args...);
154+
}
145155
};
146156

147157
#else
@@ -199,6 +209,7 @@ struct CBlas<platform::float16> {
199209
static void SMM_GEMM(...) {
200210
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
201211
}
212+
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
202213
#ifdef PADDLE_WITH_MKLML
203214
static void GEMM_BATCH(...) {
204215
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
@@ -374,6 +385,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
374385
#endif
375386
}
376387

388+
template <>
389+
template <typename T>
390+
void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
391+
T *z) const {
392+
#ifdef PADDLE_WITH_MKLML
393+
CBlas<T>::VMUL(n, x, y, z);
394+
#else
395+
// TODO: check whether OpenBLAS provides a vectorized element-wise multiply routine
396+
for (int i = 0; i < n; ++i) {
397+
z[i] = x[i] * y[i];
398+
}
399+
#endif
400+
}
401+
377402
template <>
378403
template <typename T>
379404
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,

paddle/fluid/platform/dynload/mklml.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,27 @@ extern void* mklml_dso_handle;
4949

5050
#define MKLML_ROUTINE_EACH(__macro) \
5151
__macro(cblas_sgemm); \
52-
__macro(cblas_saxpy); \
53-
__macro(cblas_scopy); \
54-
__macro(cblas_sgemv); \
55-
__macro(cblas_sgemm_batch); \
5652
__macro(cblas_dgemm); \
53+
__macro(cblas_saxpy); \
5754
__macro(cblas_daxpy); \
55+
__macro(cblas_scopy); \
5856
__macro(cblas_dcopy); \
57+
__macro(cblas_sgemv); \
5958
__macro(cblas_dgemv); \
60-
__macro(cblas_dgemm_batch); \
61-
__macro(vsAdd); \
62-
__macro(vdAdd); \
6359
__macro(cblas_sgemm_alloc); \
64-
__macro(cblas_sgemm_pack); \
65-
__macro(cblas_sgemm_compute); \
66-
__macro(cblas_sgemm_free); \
6760
__macro(cblas_dgemm_alloc); \
61+
__macro(cblas_sgemm_pack); \
6862
__macro(cblas_dgemm_pack); \
63+
__macro(cblas_sgemm_compute); \
6964
__macro(cblas_dgemm_compute); \
65+
__macro(cblas_sgemm_free); \
7066
__macro(cblas_dgemm_free); \
67+
__macro(cblas_sgemm_batch); \
68+
__macro(cblas_dgemm_batch); \
69+
__macro(vsAdd); \
70+
__macro(vdAdd); \
71+
__macro(vsMul); \
72+
__macro(vdMul); \
7173
__macro(MKL_Set_Num_Threads)
7274

7375
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);

0 commit comments

Comments
 (0)