
Commit 209f24a

Merge pull request #14051 from velconia/accelerate_embedding_grad
[1.1] Accelerate sparse embedding grad op in CPU device
2 parents: 9da9b19 + a8b1753

3 files changed, 38 insertions(+), 19 deletions(-)


paddle/fluid/operators/lookup_table_op.cc

Lines changed: 6 additions & 0 deletions
@@ -81,6 +81,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
         "Otherwise the given value indicates padding the output "
         "with zeros whenever lookup encounters it in Ids.")
         .SetDefault(kNoPadding);
+    // NOTE(minqiyang): grad_inplace is a temporary attribute;
+    // please do NOT set this attribute in the python layer.
+    AddAttr<bool>("grad_inplace",
+                  "(boolean, default false) "
+                  "Whether the grad op reuses the input's variable.")
+        .SetDefault(false);
     AddComment(R"DOC(
 Lookup Table Operator.

paddle/fluid/operators/lookup_table_op.h

Lines changed: 31 additions & 19 deletions
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/operators/math/blas.h"
 
 namespace paddle {
 namespace operators {
@@ -68,15 +69,16 @@ class LookupTableKernel : public framework::OpKernel<T> {
     const auto *table = table_t.value().data<T>();
     auto *output = output_t->mutable_data<T>(context.GetPlace());
 
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     for (int64_t i = 0; i < ids_numel; ++i) {
       if (padding_idx != kNoPadding && ids[i] == padding_idx) {
         memset(output + i * row_width, 0, row_width * sizeof(T));
       } else {
         PADDLE_ENFORCE_GE(ids[i], 0);
         auto id_index = table_t.Index(ids[i]);
         PADDLE_ENFORCE_GE(id_index, 0, "the input key should exist.");
-        memcpy(output + i * row_width, table + id_index * row_width,
-               row_width * sizeof(T));
+        blas.VCOPY(row_width, table + id_index * row_width,
+                   output + i * row_width);
       }
     }
   }
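
The hot loop above swaps a plain memcpy for blas.VCOPY, so each row copy goes through the same vectorized BLAS backend the rest of the math uses. A minimal standalone sketch of that pattern, assuming a CBLAS implementation is linked in (the function and variable names here are illustrative, not Paddle's):

#include <cblas.h>
#include <cstring>

// Copy one embedding row per id from `table` into `output`,
// zero-filling rows whose id equals `padding_idx`.
void lookup_rows(const float *table, const long long *ids, int ids_num,
                 int row_width, long long padding_idx, float *output) {
  for (int i = 0; i < ids_num; ++i) {
    if (ids[i] == padding_idx) {
      std::memset(output + i * row_width, 0, row_width * sizeof(float));
    } else {
      // cblas_scopy(n, x, incx, y, incy): a vectorized copy with stride 1,
      // byte-for-byte equivalent to the memcpy it stands in for.
      cblas_scopy(row_width, table + ids[i] * row_width, 1,
                  output + i * row_width, 1);
    }
  }
}
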
@@ -111,27 +113,37 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
       auto *ids_data = ids->data<int64_t>();
       int64_t ids_num = ids->numel();
 
-      framework::Vector<int64_t> new_rows;
-      new_rows.reserve(ids_num);
-      for (int64_t i = 0; i < ids_num; i++) {
-        new_rows.push_back(ids_data[i]);
-      }
+      std::vector<int64_t> new_rows;
+      new_rows.resize(ids_num);
+      std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t));
       d_table->set_rows(new_rows);
 
       auto *d_table_value = d_table->mutable_value();
       d_table_value->Resize({ids_num, table_dim[1]});
-      d_table_value->mutable_data<T>(context.GetPlace());
-
-      d_table->set_height(table_dim[0]);
-
-      auto *d_output_data = d_output->data<T>();
-      auto *d_table_data = d_table_value->data<T>();
-
-      auto d_output_dims = d_output->dims();
-      PADDLE_ENFORCE_EQ(
-          d_table_value->dims(),
-          framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
-      memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
+      // FIXME(minqiyang):
+      // memory optimization will NOT reuse a Tensor with SelectedRows,
+      // so we could just share the tensor here directly.
+      // However, the InferVarType method will sometimes infer the output
+      // SelectedRows to Tensor, which is a bug, so we add an attribute
+      // here to indicate the inplace path and will remove it once the
+      // InferVarType bug is fixed.
+      bool grad_inplace = context.Attr<bool>("grad_inplace");
+      if (grad_inplace) {
+        d_table_value->ShareDataWith(*d_output);
+      } else {
+        d_table_value->mutable_data<T>(context.GetPlace());
+
+        d_table->set_height(table_dim[0]);
+
+        auto *d_output_data = d_output->data<T>();
+        auto *d_table_data = d_table_value->data<T>();
+
+        auto d_output_dims = d_output->dims();
+        PADDLE_ENFORCE_EQ(
+            d_table_value->dims(),
+            framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
+        memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
+      }
     } else {
       auto *ids = context.Input<LoDTensor>("Ids");
       auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
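
Two things speed up this sparse gradient path: the row indices are copied in one bulk std::memcpy instead of an element-by-element push_back loop, and when grad_inplace is set the kernel aliases the incoming d(Out) buffer via ShareDataWith rather than allocating and copying a same-sized buffer. A toy sketch of that copy-vs-alias choice, using made-up stand-in types rather than Paddle's SelectedRows:

#include <memory>
#include <vector>

// Stand-in for a sparse gradient: rows[i] names the table row that
// row i of `value` contributes to.
struct ToySparseGrad {
  std::vector<long long> rows;
  std::shared_ptr<std::vector<float>> value;  // ids_num x row_width values
};

ToySparseGrad lookup_grad(const std::vector<long long> &ids,
                          std::shared_ptr<std::vector<float>> d_out,
                          bool grad_inplace) {
  ToySparseGrad d_table;
  d_table.rows = ids;  // one bulk copy, like the std::memcpy above
  if (grad_inplace) {
    // Fast path: the sparse gradient's values are exactly d(Out),
    // so alias the buffer and skip the allocation and copy.
    d_table.value = std::move(d_out);
  } else {
    // Safe path: deep-copy the values, as the memcpy branch does.
    d_table.value = std::make_shared<std::vector<float>>(*d_out);
  }
  return d_table;
}

The caveat in the FIXME is why the fast path is gated behind an attribute: aliasing is only safe while no other pass reuses or retypes the shared buffer.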

python/paddle/fluid/tests/unittests/dist_transformer.py

Lines changed: 1 addition & 0 deletions
@@ -1159,6 +1159,7 @@ def prepare_encoder(src_word,
             name=pos_enc_param_name,
             trainable=False,
             initializer=fluid.initializer.ConstantInitializer(0.001)))
+    src_pos_enc.stop_gradient = True
     enc_input = src_word_emb + src_pos_enc
     return layers.dropout(
         enc_input,
