Skip to content

Commit e249ad1

Browse files
author
Yancey
authored
Test dist word2vec (#7334)
* test dist word2vec * multiple trainers work
1 parent b5fda27 commit e249ad1

File tree

5 files changed

+115
-6
lines changed

5 files changed

+115
-6
lines changed

paddle/framework/executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ const std::string kFetchOpType = "fetch";
3535

3636
Executor::Executor(const platform::Place& place) : place_(place) {}
3737

38-
void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
38+
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
3939
if (var_type == proto::VarDesc::LOD_TENSOR) {
4040
var->GetMutable<LoDTensor>();
4141
} else if (var_type == proto::VarDesc::SELECTED_ROWS) {

paddle/framework/executor.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,5 @@ class Executor {
4545
const platform::Place place_;
4646
};
4747

48-
void CreateTensor(Variable* var, proto::VarDesc::VarType var_type);
49-
5048
} // namespace framework
5149
} // namespace paddle

paddle/operators/recv_op.cc

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,20 @@ limitations under the License. */
3232
namespace paddle {
3333
namespace operators {
3434

35+
static void CreateTensorFromMessageType(framework::Variable *var,
36+
sendrecv::VarType var_type) {
37+
if (var_type == sendrecv::VarType::LOD_TENSOR) {
38+
var->GetMutable<framework::LoDTensor>();
39+
} else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
40+
var->GetMutable<framework::SelectedRows>();
41+
} else {
42+
PADDLE_THROW(
43+
"VraibleMessage type %d is not in "
44+
"[LoDTensor, SelectedRows]",
45+
var_type);
46+
}
47+
}
48+
3549
void RunServer(Server **rpc_server,
3650
std::shared_ptr<detail::SendRecvServerImpl> service,
3751
const std::string &server_address) {
@@ -111,10 +125,10 @@ class RecvOp : public framework::OperatorBase {
111125
auto *merged_grad = recv_scope.FindVar(grad_var_name);
112126
if (merged_grad == nullptr) {
113127
auto *ptr = recv_scope.Var(grad_var_name);
114-
framework::CreateTensor(ptr,
115-
framework::ToVarType(merged_grad->Type()));
128+
CreateTensorFromMessageType(ptr, v.second.type());
116129
VLOG(3) << "Create Variable " << grad_var_name
117-
<< " on recv scope, which pointer is " << ptr;
130+
<< " on recv scope, which pointer is " << ptr << " type is "
131+
<< v.second.type();
118132
}
119133

120134
if (trainer_count > 1) {

paddle/operators/sum_op.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class SumKernel : public framework::OpKernel<T> {
7070
} else if (out_var->IsType<framework::SelectedRows>()) {
7171
PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now");
7272
auto *out = context.Output<SelectedRows>("Out");
73+
out->mutable_rows()->clear();
7374
auto *out_value = out->mutable_value();
7475

7576
// Runtime InferShape
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from __future__ import print_function
2+
import numpy as np
3+
import paddle.v2 as paddle
4+
import paddle.v2.fluid as fluid
5+
import os
6+
7+
PASS_NUM = 100
8+
EMBED_SIZE = 32
9+
HIDDEN_SIZE = 256
10+
N = 5
11+
BATCH_SIZE = 32
12+
IS_SPARSE = True
13+
TRAINERS = 2
14+
15+
word_dict = paddle.dataset.imikolov.build_dict()
16+
dict_size = len(word_dict)
17+
18+
first_word = fluid.layers.data(name='firstw', shape=[1], dtype='int64')
19+
second_word = fluid.layers.data(name='secondw', shape=[1], dtype='int64')
20+
third_word = fluid.layers.data(name='thirdw', shape=[1], dtype='int64')
21+
forth_word = fluid.layers.data(name='forthw', shape=[1], dtype='int64')
22+
next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')
23+
24+
embed_first = fluid.layers.embedding(
25+
input=first_word,
26+
size=[dict_size, EMBED_SIZE],
27+
dtype='float32',
28+
is_sparse=IS_SPARSE,
29+
param_attr='shared_w')
30+
embed_second = fluid.layers.embedding(
31+
input=second_word,
32+
size=[dict_size, EMBED_SIZE],
33+
dtype='float32',
34+
is_sparse=IS_SPARSE,
35+
param_attr='shared_w')
36+
embed_third = fluid.layers.embedding(
37+
input=third_word,
38+
size=[dict_size, EMBED_SIZE],
39+
dtype='float32',
40+
is_sparse=IS_SPARSE,
41+
param_attr='shared_w')
42+
embed_forth = fluid.layers.embedding(
43+
input=forth_word,
44+
size=[dict_size, EMBED_SIZE],
45+
dtype='float32',
46+
is_sparse=IS_SPARSE,
47+
param_attr='shared_w')
48+
49+
concat_embed = fluid.layers.concat(
50+
input=[embed_first, embed_second, embed_third, embed_forth], axis=1)
51+
hidden1 = fluid.layers.fc(input=concat_embed, size=HIDDEN_SIZE, act='sigmoid')
52+
predict_word = fluid.layers.fc(input=hidden1, size=dict_size, act='softmax')
53+
cost = fluid.layers.cross_entropy(input=predict_word, label=next_word)
54+
avg_cost = fluid.layers.mean(x=cost)
55+
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
56+
optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
57+
train_reader = paddle.batch(
58+
paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
59+
60+
place = fluid.CPUPlace()
61+
exe = fluid.Executor(place)
62+
63+
t = fluid.DistributeTranspiler()
64+
# all parameter server endpoints list for spliting parameters
65+
pserver_endpoints = os.getenv("PSERVERS")
66+
# server endpoint for current node
67+
current_endpoint = os.getenv("SERVER_ENDPOINT")
68+
# run as trainer or parameter server
69+
training_role = os.getenv("TRAINING_ROLE",
70+
"TRAINER") # get the training role: trainer/pserver
71+
t.transpile(
72+
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)
73+
if training_role == "PSERVER":
74+
if not current_endpoint:
75+
print("need env SERVER_ENDPOINT")
76+
exit(1)
77+
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
78+
exe.run(fluid.default_startup_program())
79+
exe.run(pserver_prog)
80+
elif training_role == "TRAINER":
81+
feeder = fluid.DataFeeder(
82+
feed_list=[first_word, second_word, third_word, forth_word, next_word],
83+
place=place)
84+
exe.run(fluid.default_startup_program())
85+
for pass_id in range(PASS_NUM):
86+
for data in train_reader():
87+
avg_cost_np = exe.run(fluid.default_main_program(),
88+
feed=feeder.feed(data),
89+
fetch_list=[avg_cost])
90+
print("avg_cost_np", avg_cost_np)
91+
if avg_cost_np[0] < 5.0:
92+
exit(
93+
0) # if avg cost less than 10.0, we think our code is good.
94+
else:
95+
print("environment var TRAINER_ROLE should be TRAINER os PSERVER")
96+
exit(1)

0 commit comments

Comments
 (0)