Skip to content

Commit b58c5ee

Browse files
authored
Merge pull request #7469 from putcn/book_demo_distributed_fit_a_line
Add book demo distributed fit a line
2 parents 5d9dcfc + cadb95f commit b58c5ee

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import numpy as np
2+
import paddle.v2 as paddle
3+
import paddle.v2.fluid as fluid
4+
import os
5+
6+
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
7+
8+
y_predict = fluid.layers.fc(input=x, size=1, act=None)
9+
10+
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
11+
12+
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
13+
avg_cost = fluid.layers.mean(x=cost)
14+
15+
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
16+
optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
17+
18+
BATCH_SIZE = 20
19+
20+
train_reader = paddle.batch(
21+
paddle.reader.shuffle(
22+
paddle.dataset.uci_housing.train(), buf_size=500),
23+
batch_size=BATCH_SIZE)
24+
25+
place = fluid.CPUPlace()
26+
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
27+
exe = fluid.Executor(place)
28+
29+
t = fluid.DistributeTranspiler()
30+
# all parameter server endpoints list for spliting parameters
31+
pserver_endpoints = os.getenv("PSERVERS")
32+
# server endpoint for current node
33+
current_endpoint = os.getenv("SERVER_ENDPOINT")
34+
# run as trainer or parameter server
35+
training_role = os.getenv("TRAINING_ROLE",
36+
"TRAINER") # get the training role: trainer/pserver
37+
t.transpile(optimize_ops, params_grads, pservers=pserver_endpoints, trainers=2)
38+
39+
if training_role == "PSERVER":
40+
if not current_endpoint:
41+
print("need env SERVER_ENDPOINT")
42+
exit(1)
43+
pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
44+
exe.run(fluid.default_startup_program())
45+
exe.run(pserver_prog)
46+
else:
47+
trainer_prog = t.get_trainer_program()
48+
49+
exe.run(fluid.default_startup_program())
50+
51+
PASS_NUM = 100
52+
for pass_id in range(PASS_NUM):
53+
fluid.io.save_persistables(exe, "./fit_a_line.model/")
54+
fluid.io.load_persistables(exe, "./fit_a_line.model/")
55+
for data in train_reader():
56+
avg_loss_value, = exe.run(trainer_prog,
57+
feed=feeder.feed(data),
58+
fetch_list=[avg_cost])
59+
60+
if avg_loss_value[0] < 10.0:
61+
exit(0)
62+
exit(1)

0 commit comments

Comments
 (0)