@@ -119,9 +119,30 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
119
119
120
120
auto *d_table_value = d_table->mutable_value ();
121
121
d_table_value->Resize ({ids_num, table_dim[1 ]});
122
+ // FIXME(minqiyang):
122
123
// memory optimization will NOT reuse Tensor with SelectedRows
123
124
// so we could just share the tensor here directly.
124
- d_table_value->ShareDataWith (*d_output);
125
+ // However, the InferVarType method sometimes infers the output
126
+ // SelectedRows as a Tensor, which is a bug. As a workaround we add an
127
+ // attribute here to indicate whether to share in place; remove this
128
+ // attribute once the InferVarType bug is fixed.
129
+ bool grad_inplace = context.Attr <bool >(" grad_inplace" );
130
+ if (grad_inplace) {
131
+ d_table_value->ShareDataWith (*d_output);
132
+ } else {
133
+ d_table_value->mutable_data <T>(context.GetPlace ());
134
+
135
+ d_table->set_height (table_dim[0 ]);
136
+
137
+ auto *d_output_data = d_output->data <T>();
138
+ auto *d_table_data = d_table_value->data <T>();
139
+
140
+ auto d_output_dims = d_output->dims ();
141
+ PADDLE_ENFORCE_EQ (
142
+ d_table_value->dims (),
143
+ framework::flatten_to_2d (d_output_dims, d_output_dims.size () - 1 ));
144
+ memcpy (d_table_data, d_output_data, sizeof (T) * d_output->numel ());
145
+ }
125
146
} else {
126
147
auto *ids = context.Input <LoDTensor>(" Ids" );
127
148
auto *d_output = context.Input <LoDTensor>(framework::GradVarName (" Out" ));
0 commit comments