
Commit 38f9b71

[cherry-pick] fix fluid.embedding (#25328)
* test=release/1.8, cherry fix fluid.embedding
1 parent b69d064 commit 38f9b71

File tree

5 files changed: 88 additions & 45 deletions


paddle/fluid/operators/distributed/parameter_prefetch.cc

Lines changed: 8 additions & 4 deletions
@@ -209,16 +209,20 @@ void prefetchs(const std::vector<std::string>& id_var_names,
   TableAndEndpoints tables;

   for (auto& id_name : id_var_names) {
-    auto& id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>();
-    auto* id_data = id_tensor.data<int64_t>();
+    auto* id_tensor =
+        scope.FindVar(id_name)->GetMutable<framework::LoDTensor>();
+    auto id_dims = id_tensor->dims();
+    id_tensor->Resize(framework::make_ddim(
+        {static_cast<int64_t>(id_dims[0] * id_dims[1]), 1}));
+    auto* id_data = id_tensor->data<int64_t>();
     std::vector<int64_t> ids;

-    for (int64_t i = 0; i < id_tensor.numel(); ++i) {
+    for (int64_t i = 0; i < id_tensor->numel(); ++i) {
       ids.push_back(id_data[i]);
       ids_union.push_back(id_data[i]);
     }
     ids_group.push_back(ids);
-    ids_lods.push_back(id_tensor.lod());
+    ids_lods.push_back(id_tensor->lod());
   }

   std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end());
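
The hunk above flattens each 2-D id tensor to a [d0 * d1, 1] column before its values are gathered into ids and ids_union. A minimal NumPy sketch of that reshape (values and variable names are illustrative, not taken from the patch):

import numpy as np

# 2-D int64 id tensor of shape [d0, d1], as lookup_table_v2-style inputs provide.
id_tensor = np.array([[3, 7], [1, 4]], dtype=np.int64)
d0, d1 = id_tensor.shape

# The patched prefetchs() resizes the tensor to [d0 * d1, 1] before reading ids.
flat_ids = id_tensor.reshape(d0 * d1, 1)

ids = [int(v) for v in flat_ids.ravel()]        # [3, 7, 1, 4]
unique_ids = sorted(set(ids))                   # mirrors the std::unordered_set step
print(flat_ids.shape, ids, unique_ids)          # (4, 1) [3, 7, 1, 4] [1, 3, 4, 7]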

paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc

Lines changed: 41 additions & 8 deletions
@@ -26,7 +26,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInputs("Ids"),
                    "Input(Ids) of LookupTableOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("W"),

@@ -40,18 +40,18 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(table_dims.size(), 2,
                       "Only 2 dimensions of the 'Embedding' is supported.");

-    for (auto &ids_dim : ids_dims) {
+    for (auto& ids_dim : ids_dims) {
       PADDLE_ENFORCE_EQ(ids_dim.size(), 2,
                         "The dimension of the 'Ids' tensor must be 2.");
-      PADDLE_ENFORCE_EQ(ids_dim[1], 1,
-                        "The last dimension of the 'Ids' tensor must be 1.");
     }

     auto lookup_tables =
         ctx->Attrs().Get<std::vector<std::string>>("table_names");
     auto height_sections =
         ctx->Attrs().Get<std::vector<int64_t>>("height_sections");
     auto endpoints = ctx->Attrs().Get<std::vector<std::string>>("endpoints");
+    auto lookup_table_version =
+        ctx->Attrs().Get<std::string>("lookup_table_version");

     PADDLE_ENFORCE(lookup_tables.size() == height_sections.size() &&
                    lookup_tables.size() == endpoints.size() &&

@@ -61,8 +61,15 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {

     auto outputs_dims = std::vector<framework::DDim>();

-    for (auto &ids_dim : ids_dims) {
-      outputs_dims.push_back(framework::make_ddim({ids_dim[0], table_dims[1]}));
+    for (auto& ids_dim : ids_dims) {
+      if (lookup_table_version == "lookup_table") {
+        outputs_dims.push_back(
+            framework::make_ddim({ids_dim[0], table_dims[1]}));
+      } else if (lookup_table_version == "lookup_table_v2") {
+        outputs_dims.push_back(framework::make_ddim(
+            {static_cast<int64_t>(ids_dim[0]), static_cast<int64_t>(ids_dim[1]),
+             static_cast<int64_t>(table_dims[1])}));
+      }
     }

     ctx->SetOutputsDim("Outputs", outputs_dims);

@@ -71,7 +78,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {

  protected:
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
+      const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
         ctx.GetPlace());

@@ -81,7 +88,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
 template <typename T>
 class DistributedLookupTableKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext &context) const override {
+  void Compute(const framework::ExecutionContext& context) const override {
     auto ids_vars = context.MultiInputVar("Ids");
     auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings");

@@ -93,10 +100,30 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
     auto height_sections =
         context.Attr<std::vector<int64_t>>("height_sections");
     auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
+    auto lookup_table_version =
+        context.Attr<std::string>("lookup_table_version");

     operators::distributed::prefetchs(
         id_names, out_names, embedding_name, false, lookup_tables, endpoints,
         height_sections, context, context.scope());
+
+    if (lookup_table_version == "lookup_table_v2") {
+      auto& scope = context.scope();
+      auto emb_dim =
+          scope.FindVar(embedding_name)->Get<framework::LoDTensor>().dims()[1];
+
+      for (size_t i = 0; i < id_names.size(); ++i) {
+        auto* id_var = scope.FindVar(id_names[i]);
+        auto* out_var = scope.FindVar(out_names[i]);
+        auto* id_tensor = id_var->GetMutable<framework::LoDTensor>();
+        auto* out_tensor = out_var->GetMutable<framework::LoDTensor>();
+
+        auto id_dims = id_tensor->dims();
+        out_tensor->Resize(framework::make_ddim(
+            {static_cast<int64_t>(id_dims[0]), static_cast<int64_t>(id_dims[1]),
+             static_cast<int64_t>(emb_dim)}));
+      }
+    }
   }
 };

@@ -134,6 +161,12 @@ class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {

     AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);

+    AddAttr<std::string>(
+        "lookup_table_version",
+        "(string, default lookup_table) "
+        "To distinguish between different versions of embedding OP")
+        .SetDefault(std::string("lookup_table"));
+
     AddAttr<int64_t>("padding_idx",
                      "(int64, default -1) "
                      "If the value is -1, it makes no effect to lookup. "

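The InferShape branch above is the heart of the fix: the output rank now depends on which embedding op the transpiler replaced. A hedged Python sketch of that shape rule (the function and variable names are hypothetical; emb_dim stands for table_dims[1]):

def infer_output_dim(ids_dim, emb_dim, lookup_table_version="lookup_table"):
    # lookup_table: ids of shape [N, 1] -> 2-D output [N, emb_dim]
    if lookup_table_version == "lookup_table":
        return [ids_dim[0], emb_dim]
    # lookup_table_v2: ids of shape [d0, d1] -> 3-D output [d0, d1, emb_dim]
    if lookup_table_version == "lookup_table_v2":
        return [ids_dim[0], ids_dim[1], emb_dim]
    raise ValueError("unknown lookup_table_version: %s" % lookup_table_version)

print(infer_output_dim([8, 1], 128))                     # [8, 128]
print(infer_output_dim([4, 6], 128, "lookup_table_v2"))  # [4, 6, 128]

The kernel's post-prefetch Resize applies the same 3-D shape to each output tensor when the v2 path is taken.
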
python/paddle/fluid/tests/unittests/dist_simnet_bow.py

Lines changed: 15 additions & 12 deletions
@@ -92,8 +92,8 @@ def train_network(batch_size,
     # query
     q = fluid.layers.data(
         name="query_ids", shape=[1], dtype="int64", lod_level=1)
-    ## embedding
-    q_emb = fluid.layers.embedding(
+    # embedding
+    q_emb = fluid.embedding(
         input=q,
         is_distributed=is_distributed,
         size=[dict_dim, emb_dim],

@@ -104,10 +104,11 @@ def train_network(batch_size,
             initializer=fluid.initializer.Constant(value=0.01),
             name="__emb__"),
         is_sparse=is_sparse)
-    ## vsum
+    q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim])
+    # vsum
     q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
     q_ss = fluid.layers.softsign(q_sum)
-    ## fc layer after conv
+    # fc layer after conv
     q_fc = fluid.layers.fc(
         input=q_ss,
         size=hid_dim,

@@ -120,8 +121,8 @@ def train_network(batch_size,
     # pt
     pt = fluid.layers.data(
         name="pos_title_ids", shape=[1], dtype="int64", lod_level=1)
-    ## embedding
-    pt_emb = fluid.layers.embedding(
+    # embedding
+    pt_emb = fluid.embedding(
         input=pt,
         is_distributed=is_distributed,
         size=[dict_dim, emb_dim],

@@ -132,10 +133,11 @@ def train_network(batch_size,
             initializer=fluid.initializer.Constant(value=0.01),
             name="__emb__"),
         is_sparse=is_sparse)
-    ## vsum
+    pt_emb = fluid.layers.reshape(pt_emb, [-1, emb_dim])
+    # vsum
     pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum')
     pt_ss = fluid.layers.softsign(pt_sum)
-    ## fc layer
+    # fc layer
     pt_fc = fluid.layers.fc(
         input=pt_ss,
         size=hid_dim,

@@ -147,8 +149,8 @@ def train_network(batch_size,
     # nt
     nt = fluid.layers.data(
         name="neg_title_ids", shape=[1], dtype="int64", lod_level=1)
-    ## embedding
-    nt_emb = fluid.layers.embedding(
+    # embedding
+    nt_emb = fluid.embedding(
         input=nt,
         is_distributed=is_distributed,
         size=[dict_dim, emb_dim],

@@ -159,10 +161,11 @@ def train_network(batch_size,
             initializer=fluid.initializer.Constant(value=0.01),
             name="__emb__"),
         is_sparse=is_sparse)
-    ## vsum
+    nt_emb = fluid.layers.reshape(nt_emb, [-1, emb_dim])
+    # vsum
     nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum')
     nt_ss = fluid.layers.softsign(nt_sum)
-    ## fc layer
+    # fc layer
     nt_fc = fluid.layers.fc(
         input=nt_ss,
         size=hid_dim,
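
The test above swaps fluid.layers.embedding for fluid.embedding, whose output keeps the trailing id dimension, so each branch adds a reshape back to [-1, emb_dim] before sequence_pool. A reduced sketch of that pattern, assuming Paddle 1.8 fluid APIs (sizes are placeholders; is_distributed and param_attr are omitted here):

import paddle.fluid as fluid

dict_dim, emb_dim = 10000, 128
q = fluid.layers.data(
    name="query_ids", shape=[1], dtype="int64", lod_level=1)
q_emb = fluid.embedding(
    input=q, size=[dict_dim, emb_dim], is_sparse=True)
# fluid.embedding keeps the extra id dimension; flatten back to 2-D for sequence_pool.
q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim])
q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
q_ss = fluid.layers.softsign(q_sum)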

python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py

Lines changed: 4 additions & 4 deletions
@@ -46,7 +46,7 @@ def _setup_config(self):
         self._sync_mode = False
         self._enforce_place = "CPU"

-    #FIXME(typhoonzero): fix async tests later
+    # FIXME(typhoonzero): fix async tests later
     def notest_simnet_bow(self):
         need_envs = {
             "IS_DISTRIBUTED": '0',

@@ -107,7 +107,7 @@ def _setup_config(self):

     def test_simnet_bow(self):
         need_envs = {
-            "IS_DISTRIBUTED": '1',
+            "IS_DISTRIBUTED": '0',
             "IS_SPARSE": '1',
             'IS_SELF_CONTAINED_LR': '1'
         }

@@ -126,7 +126,7 @@ def _setup_config(self):

     def test_simnet_bow(self):
         need_envs = {
-            "IS_DISTRIBUTED": '1',
+            "IS_DISTRIBUTED": '0',
             "IS_SPARSE": '1',
             'IS_SELF_CONTAINED_LR': '1'
         }

@@ -145,7 +145,7 @@ def _setup_config(self):

     def test_simnet_bow(self):
         need_envs = {
-            "IS_DISTRIBUTED": '1',
+            "IS_DISTRIBUTED": '0',
             "IS_SPARSE": '1',
             'IS_SELF_CONTAINED_LR': '0'
         }

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 20 additions & 17 deletions
@@ -50,8 +50,8 @@
 from ..distribute_lookup_table import find_distributed_lookup_table
 from . import collective

-LOOKUP_TABLE_TYPE = "lookup_table"
-LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
+LOOKUP_TABLE_TYPE = ["lookup_table", "lookup_table_v2"]
+LOOKUP_TABLE_GRAD_TYPE = ["lookup_table_grad", "lookup_table_v2_grad"]
 OP_NAME_SCOPE = "op_namescope"
 CLIP_OP_NAME_SCOPE = "@CLIP"
 OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()

@@ -199,10 +199,10 @@ class DistributeTranspilerConfig(object):
     geo_sgd_need_push_nums = 100

     nccl_comm_num = 1
-    #The picture here illustrates the principle:
-    #https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
+    # The picture here illustrates the principle:
+    # https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
     use_hierarchical_allreduce = False
-    #Nccl ranks in a node when use hierarchical allreduce, it's set to gpu cards' number in most cases.
+    # Nccl ranks in a node when use hierarchical allreduce, it's set to gpu cards' number in most cases.
     hierarchical_allreduce_inter_nranks = 0

     # if mode is collective

@@ -445,7 +445,7 @@ def _transpile_collective(self,

     def _get_all_remote_sparse_update_op(self, main_program):
         sparse_update_ops = []
-        sparse_update_op_types = ["lookup_table", "nce"]
+        sparse_update_op_types = ["lookup_table", "nce", "lookup_table_v2"]
         for op in main_program.global_block().ops:
             if op.type in sparse_update_op_types and op.attr(
                     'remote_prefetch') is True:

@@ -475,7 +475,7 @@ def _update_remote_sparse_update_op(self, program,
                 ops.append(op)
                 used_ops.append(idx)

-            if op_type == "lookup_table":
+            if op_type in LOOKUP_TABLE_TYPE:
                 all_ops = program.global_block().ops
                 op_idxs = [all_ops.index(op) for op in ops]
                 inputs = [

@@ -521,7 +521,8 @@ def _update_remote_sparse_update_op(self, program,
                         "height_sections": height_sections,
                         "endpoints": endpoints,
                         "padding_idx": padding_idx,
-                        "trainer_id": self.trainer_id
+                        "trainer_id": self.trainer_id,
+                        "lookup_table_version": op_type
                     })
             else:
                 raise ValueError(

@@ -609,10 +610,12 @@ def transpile(self,
                 )

                 assert trainers_num > self.config.hierarchical_allreduce_inter_nranks, \
-                    "trainers_num:{} < hierarchical_allreduce_inter_nranks:{}".format(trainers_num, self.config.hierarchical_allreduce_inter_nranks)
+                    "trainers_num:{} < hierarchical_allreduce_inter_nranks:{}".format(
+                        trainers_num, self.config.hierarchical_allreduce_inter_nranks)

                 assert trainers_num % self.config.hierarchical_allreduce_inter_nranks == 0, \
-                    "trainers_num:{} mod hierarchical_allreduce_inter_nranks:{} != 0".format(trainers_num, self.config.hierarchical_allreduce_inter_nranks)
+                    "trainers_num:{} mod hierarchical_allreduce_inter_nranks:{} != 0".format(
+                        trainers_num, self.config.hierarchical_allreduce_inter_nranks)

                 self.origin_program._hierarchical_allreduce_inter_nranks = \
                     int(self.config.hierarchical_allreduce_inter_nranks)

@@ -778,7 +781,7 @@ def transpile(self,
                 decay_dummy_output = program.global_block().create_var(
                     name=framework.generate_control_dev_var_name())
                 if self.config.runtime_split_send_recv:
-                    ## async mode, using communicator to merge and send
+                    # async mode, using communicator to merge and send
                     send_varnames = [self.counter_var.name]
                 else:
                     send_varnames = []

@@ -1015,7 +1018,7 @@ def get_trainer_program(self, wait_port=True):

         - Delete optimizer related op, because parameter updated on Pserver
         - After the op which computed gradient of each parameter, add ``Send_op`` and ``Recv_op``
-
+
         Args:
             wait_port(bool): Whether to wait for the parameter server to be ready before returning to program,
             default is True

@@ -1072,7 +1075,7 @@ def _get_trainer_startup_program(self, recv_vars, eplist):
         sparse_table_names = self._get_sparse_table_names()

         # self._fake_init_sparsetable(sparse_table_names)
-        #self._delete_trainer_optimizer(is_startup=True)
+        # self._delete_trainer_optimizer(is_startup=True)

         for varname, splited_var in six.iteritems(self.param_var_mapping):
             if varname in sparse_table_names:

@@ -1466,8 +1469,8 @@ def get_startup_program(self,
            Program: parameter server side startup program.

        Examples:
-          .. code-block:: python
-
+            .. code-block:: python
+
            pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
            trainer_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
            current_endpoint = "192.168.0.1:6174"

@@ -2661,7 +2664,7 @@ def _get_optimize_pass(self):
         for op in block.ops:
             if self._is_opt_role_op(op):
                 # Todo(chengmo): Whether clip related op belongs to Optimize guard should be discussed
-                # delete clip op from opt_ops when run in Parameter Server mode
+                # delete clip op from opt_ops when run in Parameter Server mode
                 if OP_NAME_SCOPE in op.all_attrs(
                 ) and CLIP_OP_NAME_SCOPE in op.attr(
                     OP_NAME_SCOPE

@@ -2692,7 +2695,7 @@ def _get_optimize_pass(self):
         return opt_ops, params_grads

     def _get_distribute_update_vars(self):
-        #TODO(chengmo): find more powerful and simple way to deal with these special situation
+        # TODO(chengmo): find more powerful and simple way to deal with these special situation
         """
         This Function is used for a special model, like PyramidDnn which has pyramid hash op.
         Some Parameters don't use optimizing op to update its value, but updated in its BP process.
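
The transpiler changes above treat both embedding op types as sparse lookups and record which one was replaced in the new lookup_table_version attribute. An illustrative sketch of that attribute forwarding (the helper function and its arguments are hypothetical, not transpiler internals):

LOOKUP_TABLE_TYPE = ["lookup_table", "lookup_table_v2"]

def distributed_lookup_attrs(op_type, table_names, height_sections,
                             endpoints, padding_idx, trainer_id):
    # Whichever embedding op was replaced is passed through so that
    # distributed_lookup_table can pick the matching output shape.
    assert op_type in LOOKUP_TABLE_TYPE
    return {
        "table_names": table_names,
        "height_sections": height_sections,
        "endpoints": endpoints,
        "padding_idx": padding_idx,
        "trainer_id": trainer_id,
        "lookup_table_version": op_type,
    }

attrs = distributed_lookup_attrs("lookup_table_v2", ["emb.block0"], [1000],
                                 ["127.0.0.1:6174"], -1, 0)
print(attrs["lookup_table_version"])  # lookup_table_v2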
