@@ -28,39 +28,7 @@ template <typename DeviceContext, typename T>
28
28
class ElementwiseAddKernel : public framework ::OpKernel<T> {
29
29
public:
30
30
void Compute (const framework::ExecutionContext& ctx) const override {
31
- using Tensor = framework::Tensor;
32
-
33
- auto * x = ctx.Input <Tensor>(" X" );
34
- auto * y = ctx.Input <Tensor>(" Y" );
35
- auto * z = ctx.Output <Tensor>(" Out" );
36
- z->mutable_data <T>(ctx.GetPlace ());
37
- TransformFunctor<AddFunctor<T>, T, DeviceContext> functor (
38
- x, y, z, ctx.template device_context <DeviceContext>(), AddFunctor<T>());
39
-
40
- auto x_dims = x->dims ();
41
- auto y_dims = y->dims ();
42
- PADDLE_ENFORCE_GE (x_dims.size (), y_dims.size (),
43
- " Rank of first input must >= rank of second input." );
44
-
45
- if (x_dims == y_dims) {
46
- functor.Run ();
47
- return ;
48
- }
49
-
50
- int axis = ctx.Attr <int >(" axis" );
51
- axis = (axis == -1 ? x_dims.size () - y_dims.size () : axis);
52
- PADDLE_ENFORCE (axis >= 0 && axis < x_dims.size (),
53
- " Axis should be in range [0, x_dims)" );
54
-
55
- int pre, n, post;
56
- get_mid_dims (x_dims, y_dims, axis, pre, n, post);
57
- if (post == 1 ) {
58
- functor.RunRowWise (n, pre);
59
- return ;
60
- } else {
61
- functor.RunMidWise (n, pre, post);
62
- return ;
63
- }
31
+ ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx);
64
32
}
65
33
};
66
34
0 commit comments