Skip to content

Commit ead7059

Browse files
committed
Refine code
1 parent ee8e537 commit ead7059

File tree

3 files changed

+40
-66
lines changed

3 files changed

+40
-66
lines changed

paddle/operators/elementwise_max_op.h

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,39 +28,7 @@ template <typename DeviceContext, typename T>
2828
class ElementwiseMaxKernel : public framework::OpKernel<T> {
2929
public:
3030
void Compute(const framework::ExecutionContext& ctx) const override {
31-
using Tensor = framework::Tensor;
32-
33-
auto* x = ctx.Input<Tensor>("X");
34-
auto* y = ctx.Input<Tensor>("Y");
35-
auto* z = ctx.Output<Tensor>("Out");
36-
z->mutable_data<T>(ctx.GetPlace());
37-
TransformFunctor<MaxFunctor<T>, T, DeviceContext> functor(
38-
x, y, z, ctx.template device_context<DeviceContext>(), MaxFunctor<T>());
39-
40-
auto x_dims = x->dims();
41-
auto y_dims = y->dims();
42-
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
43-
"Rank of first input must >= rank of second input.");
44-
45-
if (x_dims == y_dims) {
46-
functor.Run();
47-
return;
48-
}
49-
50-
int axis = ctx.Attr<int>("axis");
51-
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
52-
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
53-
"Axis should be in range [0, x_dims)");
54-
55-
int pre, n, post;
56-
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
57-
if (post == 1) {
58-
functor.RunRowWise(n, pre);
59-
return;
60-
} else {
61-
functor.RunMidWise(n, pre, post);
62-
return;
63-
}
31+
ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx);
6432
}
6533
};
6634

paddle/operators/elementwise_min_op.h

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,39 +28,7 @@ template <typename DeviceContext, typename T>
2828
class ElementwiseMinKernel : public framework::OpKernel<T> {
2929
public:
3030
void Compute(const framework::ExecutionContext& ctx) const override {
31-
using Tensor = framework::Tensor;
32-
33-
auto* x = ctx.Input<Tensor>("X");
34-
auto* y = ctx.Input<Tensor>("Y");
35-
auto* z = ctx.Output<Tensor>("Out");
36-
z->mutable_data<T>(ctx.GetPlace());
37-
TransformFunctor<MinFunctor<T>, T, DeviceContext> functor(
38-
x, y, z, ctx.template device_context<DeviceContext>(), MinFunctor<T>());
39-
40-
auto x_dims = x->dims();
41-
auto y_dims = y->dims();
42-
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
43-
"Rank of first input must >= rank of second input.");
44-
45-
if (x_dims == y_dims) {
46-
functor.Run();
47-
return;
48-
}
49-
50-
int axis = ctx.Attr<int>("axis");
51-
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
52-
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
53-
"Axis should be in range [0, x_dims)");
54-
55-
int pre, n, post;
56-
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
57-
if (post == 1) {
58-
functor.RunRowWise(n, pre);
59-
return;
60-
} else {
61-
functor.RunMidWise(n, pre, post);
62-
return;
63-
}
31+
ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx);
6432
}
6533
};
6634

paddle/operators/elementwise_op_function.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,5 +356,43 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
356356
return;
357357
}
358358
}
359+
360+
template <typename Functor, typename DeviceContext, typename T>
361+
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
362+
using Tensor = framework::Tensor;
363+
364+
auto* x = ctx.Input<Tensor>("X");
365+
auto* y = ctx.Input<Tensor>("Y");
366+
auto* z = ctx.Output<Tensor>("Out");
367+
z->mutable_data<T>(ctx.GetPlace());
368+
TransformFunctor<Functor, T, DeviceContext> functor(
369+
x, y, z, ctx.template device_context<DeviceContext>(), Functor());
370+
371+
auto x_dims = x->dims();
372+
auto y_dims = y->dims();
373+
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
374+
"Rank of first input must >= rank of second input.");
375+
376+
if (x_dims == y_dims) {
377+
functor.Run();
378+
return;
379+
}
380+
381+
int axis = ctx.Attr<int>("axis");
382+
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
383+
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
384+
"Axis should be in range [0, x_dims)");
385+
386+
int pre, n, post;
387+
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
388+
if (post == 1) {
389+
functor.RunRowWise(n, pre);
390+
return;
391+
} else {
392+
functor.RunMidWise(n, pre, post);
393+
return;
394+
}
395+
}
396+
359397
} // namespace operators
360398
} // namespace paddle

0 commit comments

Comments
 (0)