@@ -14,6 +14,7 @@ limitations under the License. */
14
14
15
15
#pragma once
16
16
#include " paddle/fluid/operators/elementwise_op_function.h"
17
+ #include " paddle/fluid/operators/math/blas.h"
17
18
18
19
namespace paddle {
19
20
namespace operators {
@@ -23,6 +24,37 @@ struct MulFunctor {
23
24
inline HOSTDEVICE T operator ()(T a, T b) const { return a * b; }
24
25
};
25
26
27
+ template <typename DeviceContext, typename T>
28
+ void default_elementwise_mul (const framework::ExecutionContext& ctx,
29
+ const framework::Tensor* x,
30
+ const framework::Tensor* y, framework::Tensor* z) {
31
+ int axis = ctx.Attr <int >(" axis" );
32
+ ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
33
+ MulFunctor<T>(), z);
34
+ }
35
+
36
+ template <typename DeviceContext, typename T>
37
+ typename std::enable_if<
38
+ std::is_floating_point<T>::value &&
39
+ std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
40
+ elementwise_mul (const framework::ExecutionContext& ctx,
41
+ const framework::Tensor* x, const framework::Tensor* y,
42
+ framework::Tensor* z) {
43
+ auto blas = math::GetBlas<DeviceContext, T>(ctx);
44
+ blas.VMUL (x->numel (), x->data <T>(), y->data <T>(),
45
+ z->mutable_data <T>(ctx.GetPlace ()));
46
+ }
47
+
48
+ template <typename DeviceContext, typename T>
49
+ typename std::enable_if<
50
+ !std::is_floating_point<T>::value ||
51
+ !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
52
+ elementwise_mul (const framework::ExecutionContext& ctx,
53
+ const framework::Tensor* x, const framework::Tensor* y,
54
+ framework::Tensor* z) {
55
+ default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
56
+ }
57
+
26
58
template <typename DeviceContext, typename T>
27
59
class ElementwiseMulKernel : public framework ::OpKernel<T> {
28
60
public:
@@ -33,9 +65,11 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
33
65
auto * y = ctx.Input <Tensor>(" Y" );
34
66
auto * z = ctx.Output <Tensor>(" Out" );
35
67
z->mutable_data <T>(ctx.GetPlace ());
36
- int axis = ctx.Attr <int >(" axis" );
37
- ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
38
- MulFunctor<T>(), z);
68
+ if (x->numel () == y->numel ()) {
69
+ elementwise_mul<DeviceContext, T>(ctx, x, y, z);
70
+ } else {
71
+ default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
72
+ }
39
73
}
40
74
};
41
75
0 commit comments