Skip to content

Commit 133bac2

Browse files
committed
Accelerate embedding op grad
test=develop
1 parent c26f2b2 commit 133bac2

File tree

1 file changed

+8
-18
lines changed

1 file changed

+8
-18
lines changed

paddle/fluid/operators/lookup_table_op.h

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,16 @@ class LookupTableKernel : public framework::OpKernel<T> {
6868
const auto *table = table_t.value().data<T>();
6969
auto *output = output_t->mutable_data<T>(context.GetPlace());
7070

71+
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
7172
for (int64_t i = 0; i < ids_numel; ++i) {
7273
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
7374
memset(output + i * row_width, 0, row_width * sizeof(T));
7475
} else {
7576
PADDLE_ENFORCE_GE(ids[i], 0);
7677
auto id_index = table_t.Index(ids[i]);
7778
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
78-
memcpy(output + i * row_width, table + id_index * row_width,
79-
row_width * sizeof(T));
79+
blas.VCOPY(row_width, table + id_index * row_width,
80+
output + i * row_width);
8081
}
8182
}
8283
}
@@ -111,27 +112,16 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
111112
auto *ids_data = ids->data<int64_t>();
112113
int64_t ids_num = ids->numel();
113114

114-
framework::Vector<int64_t> new_rows;
115+
std::vector<int64_t> new_rows;
115116
new_rows.reserve(ids_num);
116-
for (int64_t i = 0; i < ids_num; i++) {
117-
new_rows.push_back(ids_data[i]);
118-
}
117+
std::memcpy(new_rows.data(), ids_data, ids_num * sizeof(int64_t));
119118
d_table->set_rows(new_rows);
120119

121120
auto *d_table_value = d_table->mutable_value();
122121
d_table_value->Resize({ids_num, table_dim[1]});
123-
d_table_value->mutable_data<T>(context.GetPlace());
124-
125-
d_table->set_height(table_dim[0]);
126-
127-
auto *d_output_data = d_output->data<T>();
128-
auto *d_table_data = d_table_value->data<T>();
129-
130-
auto d_output_dims = d_output->dims();
131-
PADDLE_ENFORCE_EQ(
132-
d_table_value->dims(),
133-
framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
134-
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
122+
// memory optimization will NOT reuse Tensor with SelectedRows
123+
// so we could just share the tensor here directly.
124+
d_table_value->ShareDataWith(*d_output);
135125
} else {
136126
auto *ids = context.Input<LoDTensor>("Ids");
137127
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));

0 commit comments

Comments (0)