@@ -14,7 +14,9 @@ limitations under the License. */
14
14
15
15
#pragma once
16
16
17
+ #include " paddle/fluid/framework/eigen.h"
17
18
#include " paddle/fluid/operators/elementwise_op_function.h"
19
+ #include " paddle/fluid/operators/math/blas.h"
18
20
19
21
namespace paddle {
20
22
namespace operators {
@@ -24,19 +26,57 @@ struct AddFunctor {
24
26
inline HOSTDEVICE T operator ()(T a, T b) const { return a + b; }
25
27
};
26
28
29
+ template <typename DeviceContext, typename T>
30
+ void default_elementwise_add (const framework::ExecutionContext& ctx,
31
+ const framework::Tensor* x,
32
+ const framework::Tensor* y, framework::Tensor* z) {
33
+ int axis = ctx.Attr <int >(" axis" );
34
+ ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
35
+ AddFunctor<T>(), z);
36
+ }
37
+
38
+ template <typename DeviceContext, typename T>
39
+ typename std::enable_if<
40
+ std::is_floating_point<T>::value &&
41
+ std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
42
+ elementwise_add (const framework::ExecutionContext& ctx,
43
+ const framework::Tensor* x, const framework::Tensor* y,
44
+ framework::Tensor* z) {
45
+ auto eigen_x = framework::EigenVector<T>::Flatten (*x);
46
+ auto eigen_y = framework::EigenVector<T>::Flatten (*y);
47
+ auto eigen_z = framework::EigenVector<T>::Flatten (*z);
48
+
49
+ auto blas = math::GetBlas<DeviceContext, T>(ctx);
50
+ blas.VADD (x->numel (), eigen_x.data (), eigen_y.data (), eigen_z.data ());
51
+ }
52
+
53
+ template <typename DeviceContext, typename T>
54
+ typename std::enable_if<
55
+ !std::is_floating_point<T>::value ||
56
+ !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
57
+ elementwise_add (const framework::ExecutionContext& ctx,
58
+ const framework::Tensor* x, const framework::Tensor* y,
59
+ framework::Tensor* z) {
60
+ default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
61
+ }
62
+
27
63
template <typename DeviceContext, typename T>
28
64
class ElementwiseAddKernel : public framework ::OpKernel<T> {
29
65
public:
30
66
void Compute (const framework::ExecutionContext& ctx) const override {
31
67
using Tensor = framework::Tensor;
32
68
33
- auto * x = ctx.Input <Tensor>(" X" );
34
- auto * y = ctx.Input <Tensor>(" Y" );
35
- auto * z = ctx.Output <Tensor>(" Out" );
69
+ const auto x = ctx.Input <Tensor>(" X" );
70
+ const auto y = ctx.Input <Tensor>(" Y" );
71
+ auto z = ctx.Output <Tensor>(" Out" );
36
72
z->mutable_data <T>(ctx.GetPlace ());
37
- int axis = ctx.Attr <int >(" axis" );
38
- ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
39
- AddFunctor<T>(), z);
73
+
74
+ auto dims_equal = x->dims () == y->dims ();
75
+ if (dims_equal) {
76
+ elementwise_add<DeviceContext, T>(ctx, x, y, z);
77
+ } else {
78
+ default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
79
+ }
40
80
}
41
81
};
42
82
@@ -45,6 +85,55 @@ struct IdentityGrad {
45
85
HOSTDEVICE T operator ()(T x, T y, T out, T dout) const { return dout; }
46
86
};
47
87
88
+ template <typename DeviceContext, typename T>
89
+ void default_elementwise_add_grad (const framework::ExecutionContext& ctx,
90
+ const framework::Tensor* x,
91
+ const framework::Tensor* y,
92
+ const framework::Tensor* out,
93
+ const framework::Tensor* dout,
94
+ framework::Tensor* dx,
95
+ framework::Tensor* dy) {
96
+ int axis = ctx.Attr <int >(" axis" );
97
+
98
+ ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
99
+ ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
100
+ IdentityGrad<T>());
101
+ }
102
+
103
+ template <typename DeviceContext, typename T>
104
+ typename std::enable_if<
105
+ std::is_floating_point<T>::value &&
106
+ std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
107
+ elementwise_add_grad (const framework::ExecutionContext& ctx,
108
+ const framework::Tensor* x, const framework::Tensor* y,
109
+ const framework::Tensor* out,
110
+ const framework::Tensor* dout, framework::Tensor* dx,
111
+ framework::Tensor* dy) {
112
+ auto blas = math::GetBlas<DeviceContext, T>(ctx);
113
+
114
+ if (dx) {
115
+ blas.VCOPY (dout->numel (), dout->data <T>(),
116
+ dx->mutable_data <T>(ctx.GetPlace ()));
117
+ }
118
+
119
+ if (dy) {
120
+ blas.VCOPY (dout->numel (), dout->data <T>(),
121
+ dy->mutable_data <T>(ctx.GetPlace ()));
122
+ }
123
+ }
124
+
125
+ template <typename DeviceContext, typename T>
126
+ typename std::enable_if<
127
+ !std::is_floating_point<T>::value ||
128
+ !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
129
+ elementwise_add_grad (const framework::ExecutionContext& ctx,
130
+ const framework::Tensor* x, const framework::Tensor* y,
131
+ const framework::Tensor* out,
132
+ const framework::Tensor* dout, framework::Tensor* dx,
133
+ framework::Tensor* dy) {
134
+ default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
135
+ }
136
+
48
137
template <typename DeviceContext, typename T>
49
138
class ElementwiseAddGradKernel : public framework ::OpKernel<T> {
50
139
public:
@@ -57,10 +146,13 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
57
146
auto * dout = ctx.Input <Tensor>(framework::GradVarName (" Out" ));
58
147
auto * dx = ctx.Output <Tensor>(framework::GradVarName (" X" ));
59
148
auto * dy = ctx.Output <Tensor>(framework::GradVarName (" Y" ));
60
- int axis = ctx.Attr <int >(" axis" );
61
- ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
62
- ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
63
- IdentityGrad<T>());
149
+
150
+ if (platform::is_cpu_place (ctx.GetPlace ()) && (x->dims () == y->dims ())) {
151
+ elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
152
+ } else {
153
+ default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
154
+ dy);
155
+ }
64
156
}
65
157
};
66
158
0 commit comments