Skip to content

Commit 5dbb2e9

Browse files
authored
Small changes for sum_op to avoid zero setting. (#13923)
1 parent 1250678 commit 5dbb2e9

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

paddle/fluid/operators/sum_op.h

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,31 @@ class SumKernel : public framework::OpKernel<T> {
4343
out->mutable_data<T>(context.GetPlace());
4444
}
4545
auto result = EigenVector<T>::Flatten(*out);
46+
auto &place =
47+
*context.template device_context<DeviceContext>().eigen_device();
48+
int start = in_place ? 1 : 0;
4649
if (!in_place) {
47-
math::SetConstant<DeviceContext, T> constant_functor;
48-
constant_functor(context.template device_context<DeviceContext>(), out,
49-
0.0);
50+
if ((in_num >= 2) && in_vars[0]->IsType<framework::LoDTensor>() &&
51+
in_vars[1]->IsType<framework::LoDTensor>()) {
52+
auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
53+
auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
54+
if (in_0.numel() && in_1.numel()) {
55+
auto in_0_e = EigenVector<T>::Flatten(in_0);
56+
auto in_1_e = EigenVector<T>::Flatten(in_1);
57+
result.device(place) = in_0_e + in_1_e;
58+
start = 2;
59+
}
60+
}
61+
if (start != 2) {
62+
math::SetConstant<DeviceContext, T> constant_functor;
63+
constant_functor(context.template device_context<DeviceContext>(),
64+
out, 0.0);
65+
}
5066
}
5167

5268
math::SelectedRowsAddToTensor<DeviceContext, T> functor;
53-
auto &place =
54-
*context.template device_context<DeviceContext>().eigen_device();
5569
// If in_place, just skip the first tensor
56-
for (size_t i = in_place ? 1 : 0; i < in_num; i++) {
70+
for (size_t i = start; i < in_num; i++) {
5771
if (in_vars[i]->IsType<framework::LoDTensor>()) {
5872
auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
5973
if (in_t.numel() == 0) {

0 commit comments

Comments
 (0)