Skip to content

Commit f3cdeb9 — browse files
Authored by chengduo
Merge pull request #8820 from chengduoZH/feature/refine_elementwise_
[Speed] Refine elementwise sub,div,min,max gradient functor
2 parents: e1348e1 + 8b30fad

File tree

5 files changed: +34 additions, −277 deletions

paddle/fluid/operators/elementwise_div_op.h

Lines changed: 7 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -41,77 +41,14 @@ class ElementwiseDivKernel : public framework::OpKernel<T> {
4141
};
4242

4343
template <typename T>
44-
struct ElementwiseDivGradFunctor {
45-
template <typename Device, typename X, typename Y, typename Z, typename dX,
46-
typename dY, typename dZ>
47-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
48-
auto y_e = framework::EigenVector<T>::Flatten(*y);
49-
auto z_e = framework::EigenVector<T>::Flatten(*z);
50-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
51-
52-
if (dx) {
53-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
54-
dx_e.device(d) = dz_e / y_e;
55-
}
56-
57-
if (dy) {
58-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
59-
dy_e.device(d) = -1.0 * dz_e * z_e / y_e;
60-
}
61-
}
62-
};
63-
64-
template <typename T>
65-
struct ElementwiseDivBroadCastGradFunctor {
66-
template <typename Device, typename X, typename Y, typename Z, typename dX,
67-
typename dY, typename dZ, typename Pre, typename N>
68-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
69-
auto x_e = framework::EigenVector<T>::Flatten(*x);
70-
auto y_e = framework::EigenVector<T>::Flatten(*y);
71-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
72-
73-
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
74-
.broadcast(Eigen::DSizes<int, 2>(pre, 1))
75-
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
76-
77-
if (dx) {
78-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
79-
dx_e.device(d) = dz_e / y_e_bcast;
80-
}
81-
82-
if (dy) {
83-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
84-
dy_e.device(d) = (-1.0 * (x_e * dz_e) / (y_e_bcast * y_e_bcast))
85-
.reshape(Eigen::DSizes<int, 2>(pre, n))
86-
.sum(Eigen::array<int, 1>{{0}});
87-
}
88-
}
44+
struct DivGradDX {
45+
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; }
8946
};
9047

9148
template <typename T>
92-
struct ElementwiseDivBroadCast2GradFunctor {
93-
template <typename Device, typename X, typename Y, typename Z, typename dX,
94-
typename dY, typename dZ, typename Pre, typename N, typename Post>
95-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
96-
Post post) {
97-
auto x_e = framework::EigenVector<T>::Flatten(*x);
98-
auto y_e = framework::EigenVector<T>::Flatten(*y);
99-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
100-
101-
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
102-
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
103-
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
104-
if (dx) {
105-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
106-
dx_e.device(d) = dz_e / y_e_bcast;
107-
}
108-
109-
if (dy) {
110-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
111-
dy_e.device(d) = (-1.0 * (x_e * dz_e) / (y_e_bcast * y_e_bcast))
112-
.reshape(Eigen::DSizes<int, 3>(pre, n, post))
113-
.sum(Eigen::array<int, 2>{{0, 2}});
114-
}
49+
struct DivGradDY {
50+
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
51+
return -dout * x / (y * y);
11552
}
11653
};
11754

@@ -128,10 +65,8 @@ class ElementwiseDivGradKernel : public framework::OpKernel<T> {
12865
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
12966
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
13067
int axis = ctx.Attr<int>("axis");
131-
ElementwiseGradCompute<DeviceContext, T, ElementwiseDivGradFunctor<T>,
132-
ElementwiseDivBroadCastGradFunctor<T>,
133-
ElementwiseDivBroadCast2GradFunctor<T>>(
134-
ctx, x, y, out, dout, axis, dx, dy);
68+
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
69+
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
13570
}
13671
};
13772

paddle/fluid/operators/elementwise_max_op.h

Lines changed: 8 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -41,76 +41,16 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
4141
};
4242

4343
template <typename T>
44-
struct ElementwiseMaxGradFunctor {
45-
template <typename Device, typename X, typename Y, typename Z, typename dX,
46-
typename dY, typename dZ>
47-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
48-
auto x_e = framework::EigenVector<T>::Flatten(*x);
49-
auto y_e = framework::EigenVector<T>::Flatten(*y);
50-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
51-
52-
if (dx) {
53-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
54-
dx_e.device(d) = (x_e > y_e).template cast<T>() * dz_e;
55-
}
56-
if (dy) {
57-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
58-
dy_e.device(d) = (x_e <= y_e).template cast<T>() * dz_e;
59-
}
44+
struct MaxGradDx {
45+
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
46+
return dout * (x > y);
6047
}
6148
};
6249

6350
template <typename T>
64-
struct ElementwiseMaxBroadCastGradFunctor {
65-
template <typename Device, typename X, typename Y, typename Z, typename dX,
66-
typename dY, typename dZ, typename Pre, typename N>
67-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
68-
auto x_e = framework::EigenVector<T>::Flatten(*x);
69-
auto y_e = framework::EigenVector<T>::Flatten(*y);
70-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
71-
72-
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
73-
.broadcast(Eigen::DSizes<int, 2>(pre, 1))
74-
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
75-
76-
if (dx) {
77-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
78-
dx_e.device(d) = (x_e > y_e_bcast).template cast<T>() * dz_e;
79-
}
80-
81-
if (dy) {
82-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
83-
dy_e.device(d) = ((x_e <= y_e_bcast).template cast<T>() * dz_e)
84-
.reshape(Eigen::DSizes<int, 2>(pre, n))
85-
.sum(Eigen::array<int, 1>{{0}});
86-
}
87-
}
88-
};
89-
90-
template <typename T>
91-
struct ElementwiseMaxBroadCast2GradFunctor {
92-
template <typename Device, typename X, typename Y, typename Z, typename dX,
93-
typename dY, typename dZ, typename Pre, typename N, typename Post>
94-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
95-
Post post) {
96-
auto x_e = framework::EigenVector<T>::Flatten(*x);
97-
auto y_e = framework::EigenVector<T>::Flatten(*y);
98-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
99-
100-
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
101-
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
102-
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
103-
if (dx) {
104-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
105-
dx_e.device(d) = (x_e > y_e_bcast).template cast<T>() * dz_e;
106-
}
107-
108-
if (dy) {
109-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
110-
dy_e.device(d) = ((x_e <= y_e_bcast).template cast<T>() * dz_e)
111-
.reshape(Eigen::DSizes<int, 3>(pre, n, post))
112-
.sum(Eigen::array<int, 2>{{0, 2}});
113-
}
51+
struct MaxGradDy {
52+
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
53+
return dout * (x <= y);
11454
}
11555
};
11656

@@ -127,12 +67,9 @@ class ElementwiseMaxGradKernel : public framework::OpKernel<T> {
12767
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
12868
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
12969
int axis = ctx.Attr<int>("axis");
130-
ElementwiseGradCompute<DeviceContext, T, ElementwiseMaxGradFunctor<T>,
131-
ElementwiseMaxBroadCastGradFunctor<T>,
132-
ElementwiseMaxBroadCast2GradFunctor<T>>(
133-
ctx, x, y, out, dout, axis, dx, dy);
70+
ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
71+
ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
13472
}
13573
};
136-
13774
} // namespace operators
13875
} // namespace paddle

paddle/fluid/operators/elementwise_min_op.h

Lines changed: 8 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -41,76 +41,16 @@ class ElementwiseMinKernel : public framework::OpKernel<T> {
4141
};
4242

4343
template <typename T>
44-
struct ElementwiseMinGradFunctor {
45-
template <typename Device, typename X, typename Y, typename Z, typename dX,
46-
typename dY, typename dZ>
47-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
48-
auto x_e = framework::EigenVector<T>::Flatten(*x);
49-
auto y_e = framework::EigenVector<T>::Flatten(*y);
50-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
51-
52-
if (dx) {
53-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
54-
dx_e.device(d) = (x_e < y_e).template cast<T>() * dz_e;
55-
}
56-
if (dy) {
57-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
58-
dy_e.device(d) = (x_e >= y_e).template cast<T>() * dz_e;
59-
}
44+
struct MinGradDx {
45+
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
46+
return dout * (x < y);
6047
}
6148
};
6249

6350
template <typename T>
64-
struct ElementwiseMinBroadCastGradFunctor {
65-
template <typename Device, typename X, typename Y, typename Z, typename dX,
66-
typename dY, typename dZ, typename Pre, typename N>
67-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
68-
auto x_e = framework::EigenVector<T>::Flatten(*x);
69-
auto y_e = framework::EigenVector<T>::Flatten(*y);
70-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
71-
72-
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
73-
.broadcast(Eigen::DSizes<int, 2>(pre, 1))
74-
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
75-
76-
if (dx) {
77-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
78-
dx_e.device(d) = (x_e < y_e_bcast).template cast<T>() * dz_e;
79-
}
80-
81-
if (dy) {
82-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
83-
dy_e.device(d) = ((x_e >= y_e_bcast).template cast<T>() * dz_e)
84-
.reshape(Eigen::DSizes<int, 2>(pre, n))
85-
.sum(Eigen::array<int, 1>{{0}});
86-
}
87-
}
88-
};
89-
90-
template <typename T>
91-
struct ElementwiseMinBroadCast2GradFunctor {
92-
template <typename Device, typename X, typename Y, typename Z, typename dX,
93-
typename dY, typename dZ, typename Pre, typename N, typename Post>
94-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
95-
Post post) {
96-
auto x_e = framework::EigenVector<T>::Flatten(*x);
97-
auto y_e = framework::EigenVector<T>::Flatten(*y);
98-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
99-
100-
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
101-
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
102-
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
103-
if (dx) {
104-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
105-
dx_e.device(d) = (x_e < y_e_bcast).template cast<T>() * dz_e;
106-
}
107-
108-
if (dy) {
109-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
110-
dy_e.device(d) = ((x_e >= y_e_bcast).template cast<T>() * dz_e)
111-
.reshape(Eigen::DSizes<int, 3>(pre, n, post))
112-
.sum(Eigen::array<int, 2>{{0, 2}});
113-
}
51+
struct MinGradDy {
52+
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
53+
return dout * (x >= y);
11454
}
11555
};
11656

@@ -127,12 +67,9 @@ class ElementwiseMinGradKernel : public framework::OpKernel<T> {
12767
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
12868
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
12969
int axis = ctx.Attr<int>("axis");
130-
ElementwiseGradCompute<DeviceContext, T, ElementwiseMinGradFunctor<T>,
131-
ElementwiseMinBroadCastGradFunctor<T>,
132-
ElementwiseMinBroadCast2GradFunctor<T>>(
133-
ctx, x, y, out, dout, axis, dx, dy);
70+
ElemwiseGradCompute<DeviceContext, T, MinGradDx<T>, MinGradDy<T>>(
71+
ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
13472
}
13573
};
136-
13774
} // namespace operators
13875
} // namespace paddle

paddle/fluid/operators/elementwise_mul_op.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,15 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
4040
};
4141

4242
template <typename T>
43-
struct IdentityGrad_DX {
43+
struct MulGradDX {
4444
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
4545
};
4646

4747
template <typename T>
48-
struct IdentityGrad_DY {
48+
struct MulGradDY {
4949
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
5050
};
51+
5152
template <typename DeviceContext, typename T>
5253
class ElementwiseMulGradKernel : public framework::OpKernel<T> {
5354
public:
@@ -61,10 +62,8 @@ class ElementwiseMulGradKernel : public framework::OpKernel<T> {
6162
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
6263
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
6364
int axis = ctx.Attr<int>("axis");
64-
ElemwiseGradCompute<DeviceContext, T, IdentityGrad_DX<T>,
65-
IdentityGrad_DY<T>>(ctx, *x, *y, *out, *dout, axis, dx,
66-
dy, IdentityGrad_DX<T>(),
67-
IdentityGrad_DY<T>());
65+
ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
66+
ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
6867
}
6968
};
7069
} // namespace operators

Commit comments: 0