@@ -85,6 +85,57 @@ struct IdentityGrad {
85
85
HOSTDEVICE T operator ()(T x, T y, T out, T dout) const { return dout; }
86
86
};
87
87
88
+ template <typename DeviceContext, typename T>
89
+ void default_elementwise_add_grad (const framework::ExecutionContext& ctx,
90
+ const framework::Tensor* x,
91
+ const framework::Tensor* y,
92
+ const framework::Tensor* out,
93
+ const framework::Tensor* dout,
94
+ framework::Tensor* dx,
95
+ framework::Tensor* dy) {
96
+ int axis = ctx.Attr <int >(" axis" );
97
+
98
+ ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
99
+ ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
100
+ IdentityGrad<T>());
101
+ }
102
+
103
+ template <typename DeviceContext, typename T>
104
+ typename std::enable_if<
105
+ std::is_floating_point<T>::value &&
106
+ std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
107
+ elementwise_add_grad (const framework::ExecutionContext& ctx,
108
+ const framework::Tensor* x,
109
+ const framework::Tensor* y,
110
+ const framework::Tensor* out,
111
+ const framework::Tensor* dout,
112
+ framework::Tensor* dx, framework::Tensor* dy) {
113
+ auto blas = math::GetBlas<DeviceContext, T>(ctx);
114
+
115
+ if (dx) {
116
+ blas.VCOPY (dout->numel (), dout->data <T>(),
117
+ dx->mutable_data <T>(ctx.GetPlace ()));
118
+ }
119
+
120
+ if (dy) {
121
+ blas.VCOPY (dout->numel (), dout->data <T>(),
122
+ dy->mutable_data <T>(ctx.GetPlace ()));
123
+ }
124
+ }
125
+
126
+ template <typename DeviceContext, typename T>
127
+ typename std::enable_if<
128
+ !std::is_floating_point<T>::value ||
129
+ !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
130
+ elementwise_add_grad (const framework::ExecutionContext& ctx,
131
+ const framework::Tensor* x,
132
+ const framework::Tensor* y,
133
+ const framework::Tensor* out,
134
+ const framework::Tensor* dout,
135
+ framework::Tensor* dx, framework::Tensor* dy) {
136
+ default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
137
+ }
138
+
88
139
template <typename DeviceContext, typename T>
89
140
class ElementwiseAddGradKernel : public framework ::OpKernel<T> {
90
141
public:
@@ -97,24 +148,12 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
97
148
auto * dout = ctx.Input <Tensor>(framework::GradVarName (" Out" ));
98
149
auto * dx = ctx.Output <Tensor>(framework::GradVarName (" X" ));
99
150
auto * dy = ctx.Output <Tensor>(framework::GradVarName (" Y" ));
100
- int axis = ctx.Attr <int >(" axis" );
101
151
102
152
if (platform::is_cpu_place (ctx.GetPlace ()) && (x->dims () == y->dims ())) {
103
- auto blas = math::GetBlas<DeviceContext, T>(ctx);
104
-
105
- if (dx) {
106
- blas.VCOPY (dout->numel (), dout->data <T>(),
107
- dx->mutable_data <T>(ctx.GetPlace ()));
108
- }
109
-
110
- if (dy) {
111
- blas.VCOPY (dout->numel (), dout->data <T>(),
112
- dy->mutable_data <T>(ctx.GetPlace ()));
113
- }
153
+ elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
114
154
} else {
115
- ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
116
- ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
117
- IdentityGrad<T>());
155
+ default_elementwise_add_grad<DeviceContext, T>(
156
+ ctx, x, y, out, dout, dx, dy);
118
157
}
119
158
}
120
159
};
0 commit comments