Skip to content

Commit 1a3b38a

Browse files
committed
Polish code
test=develop
1 parent 133bac2 commit 1a3b38a

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

paddle/fluid/operators/lookup_table_op.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
8181
"Otherwise the given value indicates padding the output "
8282
"with zeros whenever lookup encounters it in Ids.")
8383
.SetDefault(kNoPadding);
84+
AddAttr<bool>("grad_inplace",
85+
"(boolean, default false) "
86+
"If true, the grad op reuses the input's variable.")
87+
.SetDefault(false);
8488
AddComment(R"DOC(
8589
Lookup Table Operator.
8690

paddle/fluid/operators/lookup_table_op.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,30 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
119119

120120
auto *d_table_value = d_table->mutable_value();
121121
d_table_value->Resize({ids_num, table_dim[1]});
122+
// FIXME(minqiyang):
122123
// memory optimization will NOT reuse Tensor with SelectedRows
123124
// so we could just share the tensor here directly.
124-
d_table_value->ShareDataWith(*d_output);
125+
// However, the InferVarType method will infer the output SelectedRows
126+
// to Tensor sometimes, which is a bug, so we will add an attribute
127+
// here to indicate the in-place reuse, and remove this attribute after
128+
// the InferVarType bug is fixed
129+
bool grad_inplace = context.Attr<bool>("grad_inplace");
130+
if (grad_inplace) {
131+
d_table_value->ShareDataWith(*d_output);
132+
} else {
133+
d_table_value->mutable_data<T>(context.GetPlace());
134+
135+
d_table->set_height(table_dim[0]);
136+
137+
auto *d_output_data = d_output->data<T>();
138+
auto *d_table_data = d_table_value->data<T>();
139+
140+
auto d_output_dims = d_output->dims();
141+
PADDLE_ENFORCE_EQ(
142+
d_table_value->dims(),
143+
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
144+
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
145+
}
125146
} else {
126147
auto *ids = context.Input<LoDTensor>("Ids");
127148
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));

0 commit comments

Comments
 (0)