Skip to content

Commit 5f44813

Browse files
authored
Merge pull request #7529 from JiayiFeng/remove_functor1
remove `functor1` of ElementwiseGradCompute
2 parents f23691d + 6ee8a2e commit 5f44813

File tree

5 files changed

+1
-40
lines changed

5 files changed

+1
-40
lines changed

paddle/operators/elementwise_add_op.h

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,23 +81,6 @@ struct ElementwiseAddGradFunctor {
8181
}
8282
};
8383

84-
template <typename T>
85-
struct ElementwiseAddOneGradFunctor {
86-
template <typename Device, typename X, typename Y, typename Z, typename dX,
87-
typename dY, typename dZ>
88-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
89-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
90-
if (dx) {
91-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
92-
dx_e.device(d) = dz_e;
93-
}
94-
if (dy) {
95-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
96-
dy_e.device(d) = dz_e.sum();
97-
}
98-
}
99-
};
100-
10184
template <typename T>
10285
struct ElementwiseAddBroadCastGradFunctor {
10386
template <typename Device, typename X, typename Y, typename Z, typename dX,
@@ -142,7 +125,6 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
142125
public:
143126
void Compute(const framework::ExecutionContext& ctx) const override {
144127
ElementwiseGradCompute<DeviceContext, T, ElementwiseAddGradFunctor<T>,
145-
ElementwiseAddOneGradFunctor<T>,
146128
ElementwiseAddBroadCastGradFunctor<T>,
147129
ElementwiseAddBroadCast2GradFunctor<T>>(ctx);
148130
}

paddle/operators/elementwise_div_op.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ class ElementwiseDivGradKernel : public framework::OpKernel<T> {
107107
public:
108108
void Compute(const framework::ExecutionContext& ctx) const override {
109109
ElementwiseGradCompute<DeviceContext, T, ElementwiseDivGradFunctor<T>,
110-
ElementwiseDivGradFunctor<T>,
111110
ElementwiseDivBroadCastGradFunctor<T>,
112111
ElementwiseDivBroadCast2GradFunctor<T>>(ctx);
113112
}

paddle/operators/elementwise_mul_op.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ class ElementwiseMulGradKernel : public framework::OpKernel<T> {
106106
public:
107107
void Compute(const framework::ExecutionContext& ctx) const override {
108108
ElementwiseGradCompute<DeviceContext, T, ElementwiseMulGradFunctor<T>,
109-
ElementwiseMulGradFunctor<T>,
110109
ElementwiseMulBroadCastGradFunctor<T>,
111110
ElementwiseMulBroadCast2GradFunctor<T>>(ctx);
112111
}

paddle/operators/elementwise_op_function.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,7 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL);
311311
EIGEN_FUNCTOR(Div, EIGEN_DIV);
312312

313313
template <typename DeviceContext, typename T, typename functor,
314-
typename functor1, typename broadcastfunctor,
315-
typename broadcast2functor>
314+
typename broadcastfunctor, typename broadcast2functor>
316315
void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
317316
using Tensor = framework::Tensor;
318317

paddle/operators/elementwise_sub_op.h

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,23 +43,6 @@ struct ElementwiseSubGradFunctor {
4343
}
4444
};
4545

46-
template <typename T>
47-
struct ElementwiseSubOneGradFunctor {
48-
template <typename Device, typename X, typename Y, typename Z, typename dX,
49-
typename dY, typename dZ>
50-
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
51-
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
52-
if (dx) {
53-
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
54-
dx_e.device(d) = dz_e;
55-
}
56-
if (dy) {
57-
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
58-
dy_e.device(d) = (-1.0) * dz_e.sum();
59-
}
60-
}
61-
};
62-
6346
template <typename T>
6447
struct ElementwiseSubBroadCastGradFunctor {
6548
template <typename Device, typename X, typename Y, typename Z, typename dX,
@@ -106,7 +89,6 @@ class ElementwiseSubGradKernel : public framework::OpKernel<T> {
10689
public:
10790
void Compute(const framework::ExecutionContext& ctx) const override {
10891
ElementwiseGradCompute<DeviceContext, T, ElementwiseSubGradFunctor<T>,
109-
ElementwiseSubOneGradFunctor<T>,
11092
ElementwiseSubBroadCastGradFunctor<T>,
11193
ElementwiseSubBroadCast2GradFunctor<T>>(ctx);
11294
}

0 commit comments

Comments
 (0)