Skip to content

Commit 958823f

Browse files
author
Yancey
authored
Merge pull request #11189 from Yancey1989/test_dist_mnist_acc
Add unit test for testing distributed training with mnist
2 parents 06bb642 + 7d9c9a0 commit 958823f

File tree

4 files changed

+213
-2
lines changed

4 files changed

+213
-2
lines changed

python/paddle/dataset/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def fetch():
111111
paddle.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
112112
paddle.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
113113
paddle.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
114-
paddle.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
114+
paddle.dataset.common.download(TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5)
115115

116116

117117
def convert(path):

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,4 @@ py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
5151
py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL)
5252
py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
5353
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
54+
set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 180)
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import argparse
17+
import time
18+
import math
19+
20+
import paddle
21+
import paddle.fluid as fluid
22+
import paddle.fluid.profiler as profiler
23+
from paddle.fluid import core
24+
import unittest
25+
from multiprocessing import Process
26+
import os
27+
import signal
28+
29+
SEED = 1
30+
DTYPE = "float32"
31+
paddle.dataset.mnist.fetch()
32+
33+
34+
# random seed must set before configuring the network.
35+
# fluid.default_startup_program().random_seed = SEED
36+
def cnn_model(data):
37+
conv_pool_1 = fluid.nets.simple_img_conv_pool(
38+
input=data,
39+
filter_size=5,
40+
num_filters=20,
41+
pool_size=2,
42+
pool_stride=2,
43+
act="relu")
44+
conv_pool_2 = fluid.nets.simple_img_conv_pool(
45+
input=conv_pool_1,
46+
filter_size=5,
47+
num_filters=50,
48+
pool_size=2,
49+
pool_stride=2,
50+
act="relu")
51+
52+
# TODO(dzhwinter) : refine the initializer and random seed settting
53+
SIZE = 10
54+
input_shape = conv_pool_2.shape
55+
param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
56+
scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5
57+
58+
predict = fluid.layers.fc(
59+
input=conv_pool_2,
60+
size=SIZE,
61+
act="softmax",
62+
param_attr=fluid.param_attr.ParamAttr(
63+
initializer=fluid.initializer.NormalInitializer(
64+
loc=0.0, scale=scale)))
65+
return predict
66+
67+
68+
def get_model(batch_size):
69+
# Input data
70+
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
71+
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
72+
73+
# Train program
74+
predict = cnn_model(images)
75+
cost = fluid.layers.cross_entropy(input=predict, label=label)
76+
avg_cost = fluid.layers.mean(x=cost)
77+
78+
# Evaluator
79+
batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
80+
batch_acc = fluid.layers.accuracy(
81+
input=predict, label=label, total=batch_size_tensor)
82+
83+
inference_program = fluid.default_main_program().clone()
84+
# Optimization
85+
opt = fluid.optimizer.AdamOptimizer(
86+
learning_rate=0.001, beta1=0.9, beta2=0.999)
87+
88+
# Reader
89+
train_reader = paddle.batch(
90+
paddle.dataset.mnist.train(), batch_size=batch_size)
91+
test_reader = paddle.batch(
92+
paddle.dataset.mnist.test(), batch_size=batch_size)
93+
opt.minimize(avg_cost)
94+
return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict
95+
96+
97+
def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
98+
t = fluid.DistributeTranspiler()
99+
t.transpile(
100+
trainer_id=trainer_id,
101+
program=main_program,
102+
pservers=pserver_endpoints,
103+
trainers=trainers)
104+
return t
105+
106+
107+
def run_pserver(pserver_endpoints, trainers, current_endpoint):
108+
get_model(batch_size=20)
109+
t = get_transpiler(0,
110+
fluid.default_main_program(), pserver_endpoints,
111+
trainers)
112+
pserver_prog = t.get_pserver_program(current_endpoint)
113+
startup_prog = t.get_startup_program(current_endpoint, pserver_prog)
114+
115+
place = fluid.CPUPlace()
116+
exe = fluid.Executor(place)
117+
exe.run(startup_prog)
118+
119+
exe.run(pserver_prog)
120+
121+
122+
class TestDistMnist(unittest.TestCase):
123+
def setUp(self):
124+
self._trainers = 1
125+
self._pservers = 1
126+
self._ps_endpoints = "127.0.0.1:9123"
127+
128+
def start_pserver(self, endpoint):
129+
p = Process(
130+
target=run_pserver,
131+
args=(self._ps_endpoints, self._trainers, endpoint))
132+
p.start()
133+
return p.pid
134+
135+
def _wait_ps_ready(self, pid):
136+
retry_times = 5
137+
while True:
138+
assert retry_times >= 0, "wait ps ready failed"
139+
time.sleep(1)
140+
try:
141+
# the listen_and_serv_op would touch a file which contains the listen port
142+
# on the /tmp directory until it was ready to process all the RPC call.
143+
os.stat("/tmp/paddle.%d.port" % pid)
144+
return
145+
except os.error:
146+
retry_times -= 1
147+
148+
def stop_pserver(self, pid):
149+
os.kill(pid, signal.SIGTERM)
150+
151+
def test_with_place(self):
152+
p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
153+
) else fluid.CPUPlace()
154+
155+
pserver_pid = self.start_pserver(self._ps_endpoints)
156+
self._wait_ps_ready(pserver_pid)
157+
158+
self.run_trainer(p, 0)
159+
160+
self.stop_pserver(pserver_pid)
161+
162+
def run_trainer(self, place, trainer_id):
163+
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = get_model(
164+
batch_size=20)
165+
t = get_transpiler(trainer_id,
166+
fluid.default_main_program(), self._ps_endpoints,
167+
self._trainers)
168+
169+
trainer_prog = t.get_trainer_program()
170+
171+
exe = fluid.Executor(place)
172+
exe.run(fluid.default_startup_program())
173+
174+
feed_var_list = [
175+
var for var in trainer_prog.global_block().vars.itervalues()
176+
if var.is_data
177+
]
178+
179+
feeder = fluid.DataFeeder(feed_var_list, place)
180+
for pass_id in xrange(10):
181+
for batch_id, data in enumerate(train_reader()):
182+
exe.run(trainer_prog, feed=feeder.feed(data))
183+
184+
if (batch_id + 1) % 10 == 0:
185+
acc_set = []
186+
avg_loss_set = []
187+
for test_data in test_reader():
188+
acc_np, avg_loss_np = exe.run(
189+
program=test_program,
190+
feed=feeder.feed(test_data),
191+
fetch_list=[batch_acc, avg_cost])
192+
acc_set.append(float(acc_np))
193+
avg_loss_set.append(float(avg_loss_np))
194+
# get test acc and loss
195+
acc_val = np.array(acc_set).mean()
196+
avg_loss_val = np.array(avg_loss_set).mean()
197+
if float(acc_val
198+
) > 0.8: # Smaller value to increase CI speed
199+
return
200+
else:
201+
print(
202+
'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'.
203+
format(pass_id, batch_id + 1,
204+
float(avg_loss_val), float(acc_val)))
205+
if math.isnan(float(avg_loss_val)):
206+
assert ("got Nan loss, training failed.")
207+
208+
209+
if __name__ == "__main__":
210+
unittest.main()

python/paddle/v2/dataset/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def fetch():
112112
paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
113113
paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
114114
paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
115-
paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
115+
paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5)
116116

117117

118118
def convert(path):

0 commit comments

Comments
 (0)