Skip to content

Commit c43995e

Browse files
author
chengduo
authored
Merge pull request #8810 from chengduoZH/feature/refine_elementwise_mul
[Speed]Refine elementwise_mul_op
2 parents 266ccaa + a1331f9 commit c43995e

File tree

2 files changed

+9
-76
lines changed

2 files changed

+9
-76
lines changed

paddle/fluid/operators/elementwise_mul_op.h

Lines changed: 8 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -40,80 +40,14 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
4040
};
4141

4242
template <typename T>
43-
struct ElementwiseMulGradFunctor {
44-
template <typename Device, typename X, typename Y, typename Z, typename dX,
45-
typename dY, typename dZ>
46-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
47-
auto x_e = framework::EigenVector<T>::Flatten(*x);
48-
auto y_e = framework::EigenVector<T>::Flatten(*y);
49-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
50-
51-
if (dx) {
52-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
53-
dx_e.device(d) = dz_e * y_e;
54-
}
55-
56-
if (dy) {
57-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
58-
dy_e.device(d) = x_e * dz_e;
59-
}
60-
}
43+
struct IdentityGrad_DX {
44+
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
6145
};
6246

6347
template <typename T>
64-
struct ElementwiseMulBroadCastGradFunctor {
65-
template <typename Device, typename X, typename Y, typename Z, typename dX,
66-
typename dY, typename dZ, typename Pre, typename N>
67-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
68-
auto x_e = framework::EigenVector<T>::Flatten(*x);
69-
auto y_e = framework::EigenVector<T>::Flatten(*y);
70-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
71-
72-
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
73-
.broadcast(Eigen::DSizes<int, 2>(pre, 1))
74-
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
75-
76-
if (dx) {
77-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
78-
dx_e.device(d) = dz_e * y_e_bcast;
79-
}
80-
81-
if (dy) {
82-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
83-
dy_e.device(d) = (x_e * dz_e)
84-
.reshape(Eigen::DSizes<int, 2>(pre, n))
85-
.sum(Eigen::array<int, 1>{{0}});
86-
}
87-
}
48+
struct IdentityGrad_DY {
49+
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
8850
};
89-
90-
template <typename T>
91-
struct ElementwiseMulBroadCast2GradFunctor {
92-
template <typename Device, typename X, typename Y, typename Z, typename dX,
93-
typename dY, typename dZ, typename Pre, typename N, typename Post>
94-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
95-
Post post) {
96-
auto x_e = framework::EigenVector<T>::Flatten(*x);
97-
auto y_e = framework::EigenVector<T>::Flatten(*y);
98-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
99-
100-
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
101-
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
102-
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
103-
if (dx) {
104-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
105-
dx_e.device(d) = dz_e * y_e_bcast;
106-
}
107-
108-
if (dy) {
109-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
110-
dy_e.device(d) = (x_e * dz_e)
111-
.reshape(Eigen::DSizes<int, 3>(pre, n, post))
112-
.sum(Eigen::array<int, 2>{{0, 2}});
113-
}
114-
}
115-
};
116-
11751
template <typename DeviceContext, typename T>
11852
class ElementwiseMulGradKernel : public framework::OpKernel<T> {
11953
public:
@@ -127,12 +61,11 @@ class ElementwiseMulGradKernel : public framework::OpKernel<T> {
12761
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
12862
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
12963
int axis = ctx.Attr<int>("axis");
130-
ElementwiseGradCompute<DeviceContext, T, ElementwiseMulGradFunctor<T>,
131-
ElementwiseMulBroadCastGradFunctor<T>,
132-
ElementwiseMulBroadCast2GradFunctor<T>>(
133-
ctx, x, y, out, dout, axis, dx, dy);
64+
ElemwiseGradCompute<DeviceContext, T, IdentityGrad_DX<T>,
65+
IdentityGrad_DY<T>>(ctx, *x, *y, *out, *dout, axis, dx,
66+
dy, IdentityGrad_DX<T>(),
67+
IdentityGrad_DY<T>());
13468
}
13569
};
136-
13770
} // namespace operators
13871
} // namespace paddle

paddle/fluid/operators/elementwise_op_function.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ struct ElemwiseGradNoBroadcast {
301301
dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
302302
}
303303
if (dy_ != nullptr) {
304-
dy_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
304+
dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
305305
}
306306
}
307307

0 commit comments

Comments
 (0)