Skip to content

Commit a56142c

Browse files
committed
optimize elementwise_mul cpu forward
1 parent 6644ce7 commit a56142c

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed

paddle/fluid/operators/elementwise_mul_op.h

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#pragma once
1616
#include "paddle/fluid/operators/elementwise_op_function.h"
17+
#include "paddle/fluid/operators/math/blas.h"
1718

1819
namespace paddle {
1920
namespace operators {
@@ -23,6 +24,37 @@ struct MulFunctor {
2324
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
2425
};
2526

27+
template <typename DeviceContext, typename T>
28+
void default_elementwise_mul(const framework::ExecutionContext& ctx,
29+
const framework::Tensor* x,
30+
const framework::Tensor* y, framework::Tensor* z) {
31+
int axis = ctx.Attr<int>("axis");
32+
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
33+
MulFunctor<T>(), z);
34+
}
35+
36+
template <typename DeviceContext, typename T>
37+
typename std::enable_if<
38+
std::is_floating_point<T>::value &&
39+
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
40+
elementwise_mul(const framework::ExecutionContext& ctx,
41+
const framework::Tensor* x, const framework::Tensor* y,
42+
framework::Tensor* z) {
43+
auto blas = math::GetBlas<DeviceContext, T>(ctx);
44+
blas.VMUL(x->numel(), x->data<T>(), y->data<T>(),
45+
z->mutable_data<T>(ctx.GetPlace()));
46+
}
47+
48+
template <typename DeviceContext, typename T>
49+
typename std::enable_if<
50+
!std::is_floating_point<T>::value ||
51+
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
52+
elementwise_mul(const framework::ExecutionContext& ctx,
53+
const framework::Tensor* x, const framework::Tensor* y,
54+
framework::Tensor* z) {
55+
default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
56+
}
57+
2658
template <typename DeviceContext, typename T>
2759
class ElementwiseMulKernel : public framework::OpKernel<T> {
2860
public:
@@ -33,9 +65,11 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
3365
auto* y = ctx.Input<Tensor>("Y");
3466
auto* z = ctx.Output<Tensor>("Out");
3567
z->mutable_data<T>(ctx.GetPlace());
36-
int axis = ctx.Attr<int>("axis");
37-
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
38-
MulFunctor<T>(), z);
68+
if (x->numel() == y->numel()) {
69+
elementwise_mul<DeviceContext, T>(ctx, x, y, z);
70+
} else {
71+
default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
72+
}
3973
}
4074
};
4175

0 commit comments

Comments
 (0)