@@ -68,15 +68,16 @@ class LookupTableKernel : public framework::OpKernel<T> {
68
68
const auto *table = table_t .value ().data <T>();
69
69
auto *output = output_t ->mutable_data <T>(context.GetPlace ());
70
70
71
+ auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
71
72
for (int64_t i = 0 ; i < ids_numel; ++i) {
72
73
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
73
74
memset (output + i * row_width, 0 , row_width * sizeof (T));
74
75
} else {
75
76
PADDLE_ENFORCE_GE (ids[i], 0 );
76
77
auto id_index = table_t .Index (ids[i]);
77
78
PADDLE_ENFORCE_GE (id_index, 0 , " the input key should be exists." );
78
- memcpy (output + i * row_width, table + id_index * row_width,
79
- row_width * sizeof (T) );
79
+ blas. VCOPY ( row_width, table + id_index * row_width,
80
+ output + i * row_width );
80
81
}
81
82
}
82
83
}
@@ -111,27 +112,16 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
111
112
auto *ids_data = ids->data <int64_t >();
112
113
int64_t ids_num = ids->numel ();
113
114
114
- framework::Vector <int64_t > new_rows;
115
+ std::vector <int64_t > new_rows;
115
116
new_rows.reserve (ids_num);
116
- for (int64_t i = 0 ; i < ids_num; i++) {
117
- new_rows.push_back (ids_data[i]);
118
- }
117
+ std::memcpy (new_rows.data (), ids_data, ids_num * sizeof (int64_t ));
119
118
d_table->set_rows (new_rows);
120
119
121
120
auto *d_table_value = d_table->mutable_value ();
122
121
d_table_value->Resize ({ids_num, table_dim[1 ]});
123
- d_table_value->mutable_data <T>(context.GetPlace ());
124
-
125
- d_table->set_height (table_dim[0 ]);
126
-
127
- auto *d_output_data = d_output->data <T>();
128
- auto *d_table_data = d_table_value->data <T>();
129
-
130
- auto d_output_dims = d_output->dims ();
131
- PADDLE_ENFORCE_EQ (
132
- d_table_value->dims (),
133
- framework::flatten_to_2d (d_output_dims, d_output_dims.size () - 1 ));
134
- memcpy (d_table_data, d_output_data, sizeof (T) * d_output->numel ());
122
+ // memory optimization will NOT reuse Tensor with SelectedRows
123
+ // so we could just share the tensor here directly.
124
+ d_table_value->ShareDataWith (*d_output);
135
125
} else {
136
126
auto *ids = context.Input <LoDTensor>(" Ids" );
137
127
auto *d_output = context.Input <LoDTensor>(framework::GradVarName (" Out" ));
0 commit comments