@@ -144,41 +144,16 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
144
144
auto * dout = ctx.Input <Tensor>(framework::GradVarName (" Out" ));
145
145
auto * dx = ctx.Output <Tensor>(framework::GradVarName (" X" ));
146
146
auto * dy = ctx.Output <Tensor>(framework::GradVarName (" Y" ));
147
+ // skip out, x, y
148
+ auto * out = dout;
149
+ auto *x = dout, *y = dout;
147
150
148
- if (dx != nullptr ) {
149
- // In fact, we can just share memory, but it may cause a bug of memory
150
- // optimizer
151
- // dx->ShareDataWith(*dout);
152
- framework::TensorCopy (*dout, ctx.GetPlace (),
153
- ctx.template device_context <DeviceContext>(), dx);
154
- }
155
-
156
- if (dy == nullptr ) return ;
157
-
158
- const framework::DDim& x_dim = dout->dims ();
159
- framework::DDim y_dim = dy->dims ();
160
- if (x_dim == y_dim) {
161
- // dy->ShareDataWith(*dout);
162
- framework::TensorCopy (*dout, ctx.GetPlace (),
163
- ctx.template device_context <DeviceContext>(), dy);
151
+ if (platform::is_cpu_place (ctx.GetPlace ()) && dx != nullptr &&
152
+ dy != nullptr && (dx->dims () == dy->dims ())) {
153
+ elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
164
154
} else {
165
- dy->mutable_data <T>(ctx.GetPlace ());
166
- // Perform reduction to dout to calculate dy
167
- int axis = ctx.Attr <int >(" axis" );
168
- axis = (axis == -1 ? x_dim.size () - y_dim.size () : axis);
169
- y_dim = trim_trailing_singular_dims (y_dim);
170
- axis = (y_dim.size () == 0 ) ? x_dim.size () : axis;
171
-
172
- auto & device =
173
- *(ctx.template device_context <DeviceContext>().eigen_device ());
174
- int pre, n, post;
175
- get_mid_dims (x_dim, y_dim, axis, &pre, &n, &post);
176
- auto eigen_dout = framework::EigenTensor<T, 3 >::From (
177
- *dout, framework::make_ddim ({pre, n, post}));
178
- auto eigen_dy =
179
- framework::EigenTensor<T, 1 >::From (*dy, framework::make_ddim ({n}));
180
- eigen_dy.device (device) = eigen_dout.sum (
181
- framework::EigenDim<2 >::From (framework::make_ddim ({0 , 2 })));
155
+ default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
156
+ dy);
182
157
}
183
158
}
184
159
};
0 commit comments