@@ -40,80 +40,14 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
40
40
};
41
41
42
42
template <typename T>
43
- struct ElementwiseMulGradFunctor {
44
- template <typename Device, typename X, typename Y, typename Z, typename dX,
45
- typename dY, typename dZ>
46
- void operator ()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
47
- auto x_e = framework::EigenVector<T>::Flatten (*x);
48
- auto y_e = framework::EigenVector<T>::Flatten (*y);
49
- auto dz_e = framework::EigenVector<T>::Flatten (*dz);
50
-
51
- if (dx) {
52
- auto dx_e = framework::EigenVector<T>::Flatten (*dx);
53
- dx_e.device (d) = dz_e * y_e;
54
- }
55
-
56
- if (dy) {
57
- auto dy_e = framework::EigenVector<T>::Flatten (*dy);
58
- dy_e.device (d) = x_e * dz_e;
59
- }
60
- }
43
+ struct IdentityGrad_DX {
44
+ HOSTDEVICE T operator ()(T x, T y, T out, T dout) const { return dout * y; }
61
45
};
62
46
63
47
template <typename T>
64
- struct ElementwiseMulBroadCastGradFunctor {
65
- template <typename Device, typename X, typename Y, typename Z, typename dX,
66
- typename dY, typename dZ, typename Pre, typename N>
67
- void operator ()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
68
- auto x_e = framework::EigenVector<T>::Flatten (*x);
69
- auto y_e = framework::EigenVector<T>::Flatten (*y);
70
- auto dz_e = framework::EigenVector<T>::Flatten (*dz);
71
-
72
- auto y_e_bcast = y_e.reshape (Eigen::DSizes<int , 2 >(1 , n))
73
- .broadcast (Eigen::DSizes<int , 2 >(pre, 1 ))
74
- .reshape (Eigen::DSizes<int , 1 >(x_e.size ()));
75
-
76
- if (dx) {
77
- auto dx_e = framework::EigenVector<T>::Flatten (*dx);
78
- dx_e.device (d) = dz_e * y_e_bcast;
79
- }
80
-
81
- if (dy) {
82
- auto dy_e = framework::EigenVector<T>::Flatten (*dy);
83
- dy_e.device (d) = (x_e * dz_e)
84
- .reshape (Eigen::DSizes<int , 2 >(pre, n))
85
- .sum (Eigen::array<int , 1 >{{0 }});
86
- }
87
- }
48
+ struct IdentityGrad_DY {
49
+ HOSTDEVICE T operator ()(T x, T y, T out, T dout) const { return dout * x; }
88
50
};
89
-
90
- template <typename T>
91
- struct ElementwiseMulBroadCast2GradFunctor {
92
- template <typename Device, typename X, typename Y, typename Z, typename dX,
93
- typename dY, typename dZ, typename Pre, typename N, typename Post>
94
- void operator ()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
95
- Post post) {
96
- auto x_e = framework::EigenVector<T>::Flatten (*x);
97
- auto y_e = framework::EigenVector<T>::Flatten (*y);
98
- auto dz_e = framework::EigenVector<T>::Flatten (*dz);
99
-
100
- auto y_e_bcast = y_e.reshape (Eigen::DSizes<int , 3 >(1 , n, 1 ))
101
- .broadcast (Eigen::DSizes<int , 3 >(pre, 1 , post))
102
- .reshape (Eigen::DSizes<int , 1 >(x_e.size ()));
103
- if (dx) {
104
- auto dx_e = framework::EigenVector<T>::Flatten (*dx);
105
- dx_e.device (d) = dz_e * y_e_bcast;
106
- }
107
-
108
- if (dy) {
109
- auto dy_e = framework::EigenVector<T>::Flatten (*dy);
110
- dy_e.device (d) = (x_e * dz_e)
111
- .reshape (Eigen::DSizes<int , 3 >(pre, n, post))
112
- .sum (Eigen::array<int , 2 >{{0 , 2 }});
113
- }
114
- }
115
- };
116
-
117
51
template <typename DeviceContext, typename T>
118
52
class ElementwiseMulGradKernel : public framework ::OpKernel<T> {
119
53
public:
@@ -127,12 +61,11 @@ class ElementwiseMulGradKernel : public framework::OpKernel<T> {
127
61
auto * dx = ctx.Output <Tensor>(framework::GradVarName (" X" ));
128
62
auto * dy = ctx.Output <Tensor>(framework::GradVarName (" Y" ));
129
63
int axis = ctx.Attr <int >(" axis" );
130
- ElementwiseGradCompute <DeviceContext, T, ElementwiseMulGradFunctor <T>,
131
- ElementwiseMulBroadCastGradFunctor <T>,
132
- ElementwiseMulBroadCast2GradFunctor <T>>(
133
- ctx, x, y, out, dout, axis, dx, dy );
64
+ ElemwiseGradCompute <DeviceContext, T, IdentityGrad_DX <T>,
65
+ IdentityGrad_DY <T>>(ctx, *x, *y, *out, *dout, axis, dx ,
66
+ dy, IdentityGrad_DX <T>(),
67
+ IdentityGrad_DY<T>() );
134
68
}
135
69
};
136
-
137
70
} // namespace operators
138
71
} // namespace paddle
0 commit comments