Skip to content

Commit a85bf42

Browse files
authored
Merge pull request #12681 from PaddlePaddle/revert-12554-refine_elementwise_add
Revert "Refine elementwise_add op"
2 parents 8b77448 + 6a2a9a8 commit a85bf42

File tree

2 files changed

+8
-87
lines changed

2 files changed

+8
-87
lines changed

paddle/fluid/operators/elementwise_add_op.cu

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -16,60 +16,6 @@ limitations under the License. */
1616
#include "paddle/fluid/operators/elementwise_add_op.h"
1717
#include "paddle/fluid/platform/float16.h"
1818

19-
namespace paddle {
20-
namespace operators {
21-
22-
template <typename T>
23-
__global__ void ElementwiseAddCUDAKernel(const T *x, const T *y, T *z, int n,
24-
int post, int size) {
25-
int idx_x = threadIdx.x + blockIdx.x * blockDim.x;
26-
if (idx_x < size) {
27-
int idx_y = idx_x / post - (idx_x / (n * post)) * n;
28-
z[idx_x] = x[idx_x] + y[idx_y];
29-
}
30-
}
31-
32-
template <typename T>
33-
class ElementwiseAddKernel<platform::CUDADeviceContext, T>
34-
: public framework::OpKernel<T> {
35-
public:
36-
void Compute(const framework::ExecutionContext &ctx) const override {
37-
using Tensor = framework::Tensor;
38-
39-
const auto x = ctx.Input<Tensor>("X");
40-
const auto y = ctx.Input<Tensor>("Y");
41-
auto z = ctx.Output<Tensor>("Out");
42-
auto *z_data = z->mutable_data<T>(ctx.GetPlace());
43-
44-
auto &device = *(ctx.cuda_device_context().eigen_device());
45-
const framework::DDim &x_dim = x->dims();
46-
framework::DDim y_dim = y->dims();
47-
int size = x->numel();
48-
if (x_dim == y_dim) {
49-
auto dim = framework::make_ddim({size});
50-
auto z_eigen = framework::EigenTensor<T, 1>::From(*z, dim);
51-
auto x_eigen = framework::EigenTensor<T, 1>::From(*x, dim);
52-
auto y_eigen = framework::EigenTensor<T, 1>::From(*y, dim);
53-
z_eigen.device(device) = x_eigen + y_eigen;
54-
} else {
55-
int axis = ctx.Attr<int>("axis");
56-
axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
57-
y_dim = trim_trailing_singular_dims(y_dim);
58-
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
59-
int pre, n, post;
60-
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
61-
int threads = 512;
62-
int grids = (size + threads - 1) / threads;
63-
auto stream = ctx.cuda_device_context().stream();
64-
ElementwiseAddCUDAKernel<T><<<grids, threads, 0, stream>>>(
65-
x->data<T>(), y->data<T>(), z_data, n, post, size);
66-
}
67-
}
68-
};
69-
70-
} // namespace operators
71-
} // namespace paddle
72-
7319
namespace ops = paddle::operators;
7420
namespace plat = paddle::platform;
7521

paddle/fluid/operators/elementwise_add_op.h

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -144,41 +144,16 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
144144
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
145145
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
146146
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
147+
// skip out, x, y
148+
auto* out = dout;
149+
auto *x = dout, *y = dout;
147150

148-
if (dx != nullptr) {
149-
// In fact, we can just share memory, but it may cause a bug of memory
150-
// optimizer
151-
// dx->ShareDataWith(*dout);
152-
framework::TensorCopy(*dout, ctx.GetPlace(),
153-
ctx.template device_context<DeviceContext>(), dx);
154-
}
155-
156-
if (dy == nullptr) return;
157-
158-
const framework::DDim& x_dim = dout->dims();
159-
framework::DDim y_dim = dy->dims();
160-
if (x_dim == y_dim) {
161-
// dy->ShareDataWith(*dout);
162-
framework::TensorCopy(*dout, ctx.GetPlace(),
163-
ctx.template device_context<DeviceContext>(), dy);
151+
if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr &&
152+
dy != nullptr && (dx->dims() == dy->dims())) {
153+
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
164154
} else {
165-
dy->mutable_data<T>(ctx.GetPlace());
166-
// Perform reduction to dout to calculate dy
167-
int axis = ctx.Attr<int>("axis");
168-
axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
169-
y_dim = trim_trailing_singular_dims(y_dim);
170-
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
171-
172-
auto& device =
173-
*(ctx.template device_context<DeviceContext>().eigen_device());
174-
int pre, n, post;
175-
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
176-
auto eigen_dout = framework::EigenTensor<T, 3>::From(
177-
*dout, framework::make_ddim({pre, n, post}));
178-
auto eigen_dy =
179-
framework::EigenTensor<T, 1>::From(*dy, framework::make_ddim({n}));
180-
eigen_dy.device(device) = eigen_dout.sum(
181-
framework::EigenDim<2>::From(framework::make_ddim({0, 2})));
155+
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
156+
dy);
182157
}
183158
}
184159
};

0 commit comments

Comments
 (0)