@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/operators/math/blas.h"
 
 namespace paddle {
 namespace operators {
@@ -68,15 +69,16 @@ class LookupTableKernel : public framework::OpKernel<T> {
       const auto *table = table_t.value().data<T>();
       auto *output = output_t->mutable_data<T>(context.GetPlace());
 
+      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
       for (int64_t i = 0; i < ids_numel; ++i) {
         if (padding_idx != kNoPadding && ids[i] == padding_idx) {
           memset(output + i * row_width, 0, row_width * sizeof(T));
         } else {
           PADDLE_ENFORCE_GE(ids[i], 0);
           auto id_index = table_t.Index(ids[i]);
           PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
-          memcpy(output + i * row_width, table + id_index * row_width,
-                 row_width * sizeof(T));
+          blas.VCOPY(row_width, table + id_index * row_width,
+                     output + i * row_width);
         }
       }
     }
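
For reference, `blas.VCOPY(n, x, y)` copies `n` contiguous elements from `x` to `y`, which for a CBLAS-backed build maps to the BLAS level-1 copy routine. A minimal sketch of the equivalent direct call for `T = float` (assumption: MKL or OpenBLAS is linked; the wrapper function `copy_row` is illustrative, not part of the patch):

// Sketch: what blas.VCOPY(row_width, src, dst) amounts to for T = float
// when a CBLAS backend is available (an assumption of this sketch).
#include <cblas.h>

void copy_row(const float *table, float *output, int64_t id_index,
              int64_t i, int row_width) {
  // cblas_scopy(n, x, incx, y, incy): copy n floats from x to y with unit
  // stride -- same effect as the old memcpy, but routed through the (often
  // vectorized) BLAS implementation.
  cblas_scopy(row_width, table + id_index * row_width, 1,
              output + i * row_width, 1);
}
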
@@ -111,27 +113,37 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
       auto *ids_data = ids->data<int64_t>();
       int64_t ids_num = ids->numel();
 
-      framework::Vector<int64_t> new_rows;
-      new_rows.reserve(ids_num);
-      for (int64_t i = 0; i < ids_num; i++) {
-        new_rows.push_back(ids_data[i]);
-      }
+      std::vector<int64_t> new_rows;
+      new_rows.resize(ids_num);
+      std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t));
       d_table->set_rows(new_rows);
 
       auto *d_table_value = d_table->mutable_value();
       d_table_value->Resize({ids_num, table_dim[1]});
-      d_table_value->mutable_data<T>(context.GetPlace());
-
-      d_table->set_height(table_dim[0]);
-
-      auto *d_output_data = d_output->data<T>();
-      auto *d_table_data = d_table_value->data<T>();
-
-      auto d_output_dims = d_output->dims();
-      PADDLE_ENFORCE_EQ(
-          d_table_value->dims(),
-          framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
-      memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
+      // FIXME(minqiyang):
+      // memory optimization will NOT reuse Tensor with SelectedRows
+      // so we could just share the tensor here directly.
+      // However, the InferVarType method will infer the output SelectedRows
+      // to Tensor sometimes, which is a bug, so we will add an attribute
+      // here to indicate the inplace and remove this attribute after
+      // the InferVarType's bug was fixed
+      bool grad_inplace = context.Attr<bool>("grad_inplace");
+      if (grad_inplace) {
+        d_table_value->ShareDataWith(*d_output);
+      } else {
+        d_table_value->mutable_data<T>(context.GetPlace());
+
+        d_table->set_height(table_dim[0]);
+
+        auto *d_output_data = d_output->data<T>();
+        auto *d_table_data = d_table_value->data<T>();
+
+        auto d_output_dims = d_output->dims();
+        PADDLE_ENFORCE_EQ(
+            d_table_value->dims(),
+            framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
+        memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
+      }
     } else {
       auto *ids = context.Input<LoDTensor>("Ids");
       auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
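
Two things change in this grad-kernel hunk: the row-index list is now filled with a single `std::memcpy` into a `std::vector<int64_t>` instead of element-wise `push_back` into `framework::Vector` (Paddle's mixed CPU/GPU vector), and, per the FIXME, a new `grad_inplace` attribute lets the kernel share the output-gradient buffer via `ShareDataWith` instead of allocating and copying. The matching attribute declaration lives in the op maker (lookup_table_op.cc, outside this diff); a sketch of the usual Paddle pattern follows -- only the attribute name comes from the kernel, the default value and description text are assumptions:

// Sketch of the grad_inplace declaration in LookupTableOpMaker
// (lookup_table_op.cc is not part of this diff; the default and the
// description wording are illustrative assumptions).
AddAttr<bool>("grad_inplace",
              "(bool, default false) Share W@GRAD's tensor with Out@GRAD "
              "directly instead of allocating and copying; temporary "
              "workaround until the InferVarType bug is fixed.")
    .SetDefault(false);
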