Skip to content

Commit 72aef6b

Browse files
committed
sum selected rows check empty
1 parent f13ae13 commit 72aef6b

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

paddle/fluid/operators/sum_op.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,22 @@ class SumKernel : public framework::OpKernel<T> {
116116
auto *out = context.Output<SelectedRows>("Out");
117117
out->mutable_rows()->clear();
118118

119-
math::scatter::MergeAdd<DeviceContext, T> merge_add;
120-
merge_add(context.template device_context<DeviceContext>(), inputs, out);
119+
bool has_data = false;
120+
for (auto &in : inputs) {
121+
if (in->rows().size() > 0) {
122+
has_data = true;
123+
break;
124+
}
125+
}
126+
if (has_data) {
127+
math::scatter::MergeAdd<DeviceContext, T> merge_add;
128+
merge_add(context.template device_context<DeviceContext>(), inputs,
129+
out);
130+
} else {
131+
// no data, just set a empty out tensor.
132+
out->mutable_value()->mutable_data<T>(framework::make_ddim({0}),
133+
context.GetPlace());
134+
}
121135
} else if (out_var->IsType<framework::LoDTensorArray>()) {
122136
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
123137
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {

0 commit comments

Comments
 (0)