Skip to content

Commit 8640456

Browse files
Xrekiphlrain
authored andcommitted
Optimize fused_elewise_activation_grad op. (#18282)
test=release/1.5
1 parent bdba5e7 commit 8640456

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

paddle/fluid/operators/elementwise/elementwise_op_function.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,24 +1005,24 @@ template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
10051005
bool UseIntermediateOut>
10061006
struct FusedElemwiseAndActGradNoBroadcast {
10071007
HOSTDEVICE void operator()(size_t i) {
1008+
T x_val = x_[i];
1009+
T y_val = y_[i];
1010+
T out_val = out_[i];
1011+
T dout_val = dout_[i];
1012+
T intermediate_out_val = UseIntermediateOut
1013+
? intermediate_out_[i]
1014+
: dx_op_.GetIntermediateOut(x_val, y_val);
10081015
if (dx_ != nullptr) {
1009-
dx_[i] = UseIntermediateOut
1010-
? dx_op_.UseIntermediateOut(
1011-
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
1012-
: dx_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
1016+
dx_[i] = dx_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
1017+
out_val, dout_val);
10131018
}
10141019
if (dy_ != nullptr) {
1015-
dy_[i] = UseIntermediateOut
1016-
? dy_op_.UseIntermediateOut(
1017-
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
1018-
: dy_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
1020+
dy_[i] = dy_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
1021+
out_val, dout_val);
10191022
}
10201023
if (dintermediate_ != nullptr) {
1021-
dintermediate_[i] =
1022-
UseIntermediateOut
1023-
? dintermediate_op_.UseIntermediateOut(
1024-
x_[i], intermediate_out_[i], out_[i], dout_[i])
1025-
: dintermediate_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
1024+
dintermediate_[i] = dintermediate_op_.UseIntermediateOut(
1025+
x_val, intermediate_out_val, out_val, dout_val);
10261026
}
10271027
}
10281028

paddle/fluid/operators/math/compound_functors.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ struct BinaryCompoundGradDxFunctor {
7474
return dout * d_binary_fun_.Dx(x, intermediate_out);
7575
}
7676

77+
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
78+
7779
private:
7880
DBinaryFun d_binary_fun_;
7981
UnaryFun unary_fun_;
@@ -105,6 +107,8 @@ struct BinaryCompoundGradDyFunctor {
105107
}
106108
}
107109

110+
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
111+
108112
private:
109113
DBinaryFun d_binary_fun_;
110114
UnaryFun unary_fun_;
@@ -143,6 +147,8 @@ struct UnaryCompoundGradDxFunctor {
143147
return base * d_binary_fun_.Dx(x, y);
144148
}
145149

150+
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
151+
146152
private:
147153
DUnaryFun d_unary_fun_;
148154
BinaryFun binary_fun_;
@@ -181,6 +187,8 @@ struct UnaryCompoundGradDyFunctor {
181187
return base * d_binary_fun_.Dy(x, y);
182188
}
183189

190+
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
191+
184192
private:
185193
DUnaryFun d_unary_fun_;
186194
BinaryFun binary_fun_;
@@ -203,6 +211,8 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor {
203211
return dout * d_binary_fun_.Dy(x, intermediate_out);
204212
}
205213

214+
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
215+
206216
private:
207217
DBinaryFun d_binary_fun_;
208218
UnaryFun unary_fun_;
@@ -232,6 +242,8 @@ struct UnaryCompoundGradDIntermediateFunctor {
232242
}
233243
}
234244

245+
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
246+
235247
private:
236248
DUnaryFun d_unary_fun_;
237249
BinaryFun binary_fun_;

0 commit comments

Comments
 (0)