File tree Expand file tree Collapse file tree 1 file changed +16
-2
lines changed Expand file tree Collapse file tree 1 file changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -116,8 +116,22 @@ class SumKernel : public framework::OpKernel<T> {
116
116
auto *out = context.Output <SelectedRows>(" Out" );
117
117
out->mutable_rows ()->clear ();
118
118
119
- math::scatter::MergeAdd<DeviceContext, T> merge_add;
120
- merge_add (context.template device_context <DeviceContext>(), inputs, out);
119
+ bool has_data = false ;
120
+ for (auto &in : inputs) {
121
+ if (in->rows ().size () > 0 ) {
122
+ has_data = true ;
123
+ break ;
124
+ }
125
+ }
126
+ if (has_data) {
127
+ math::scatter::MergeAdd<DeviceContext, T> merge_add;
128
+ merge_add (context.template device_context <DeviceContext>(), inputs,
129
+ out);
130
+ } else {
131
+ // no data, just set a empty out tensor.
132
+ out->mutable_value ()->mutable_data <T>(framework::make_ddim ({0 }),
133
+ context.GetPlace ());
134
+ }
121
135
} else if (out_var->IsType <framework::LoDTensorArray>()) {
122
136
auto &out_array = *out_var->GetMutable <framework::LoDTensorArray>();
123
137
for (size_t i = in_place ? 1 : 0 ; i < in_vars.size (); ++i) {
You can’t perform that action at this time.
0 commit comments