Skip to content

Commit f59599a

Browse files
committed
refine elementwise_add_op
1 parent ead7059 commit f59599a

File tree

1 file changed

+1
-33
lines changed

1 file changed

+1
-33
lines changed

paddle/operators/elementwise_add_op.h

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,39 +28,7 @@ template <typename DeviceContext, typename T>
2828
class ElementwiseAddKernel : public framework::OpKernel<T> {
2929
public:
3030
void Compute(const framework::ExecutionContext& ctx) const override {
31-
using Tensor = framework::Tensor;
32-
33-
auto* x = ctx.Input<Tensor>("X");
34-
auto* y = ctx.Input<Tensor>("Y");
35-
auto* z = ctx.Output<Tensor>("Out");
36-
z->mutable_data<T>(ctx.GetPlace());
37-
TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
38-
x, y, z, ctx.template device_context<DeviceContext>(), AddFunctor<T>());
39-
40-
auto x_dims = x->dims();
41-
auto y_dims = y->dims();
42-
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
43-
"Rank of first input must >= rank of second input.");
44-
45-
if (x_dims == y_dims) {
46-
functor.Run();
47-
return;
48-
}
49-
50-
int axis = ctx.Attr<int>("axis");
51-
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
52-
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
53-
"Axis should be in range [0, x_dims)");
54-
55-
int pre, n, post;
56-
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
57-
if (post == 1) {
58-
functor.RunRowWise(n, pre);
59-
return;
60-
} else {
61-
functor.RunMidWise(n, pre, post);
62-
return;
63-
}
31+
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx);
6432
}
6533
};
6634

0 commit comments

Comments
 (0)