Skip to content

Commit c083ee7

Browse files
authored
Merge pull request #9950 from JiayiFeng/add_parallel_executor_tests
Add parallel executor tests
2 parents 61f4baa + e84d3a7 commit c083ee7

File tree

2 files changed

+58
-14
lines changed

2 files changed

+58
-14
lines changed

python/paddle/fluid/parallel_executor.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import multiprocessing
1717
import framework
1818
import executor
19+
import warnings
1920
import sys
2021

2122
__all__ = ['ParallelExecutor']
@@ -62,8 +63,8 @@ def __init__(self,
6263
main_program=test_program,
6364
share_vars_from=train_exe)
6465
65-
train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
66-
test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
66+
train_loss, = train_exe.run([loss.name], feed=feed_dict)
67+
test_loss, = test_exe.run([loss.name], feed=feed_dict)
6768
"""
6869

6970
self._places = []
@@ -103,8 +104,8 @@ def __init__(self,
103104

104105
self.persistable_vars = [
105106
v.name
106-
for v in filter(lambda var: \
107-
var.persistable and var.type != core.VarDesc.VarType.RAW,
107+
for v in filter(
108+
lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW,
108109
main.list_vars())
109110
]
110111

@@ -163,7 +164,7 @@ def run(self, fetch_list, feed=None, feed_dict=None):
163164
Returns: fetched result list.
164165
165166
"""
166-
if feed is None:
167+
if feed is None and feed_dict is not None:
167168
feed = feed_dict
168169
print >> sys.stderr, "`feed_dict` is deprecated. Please use `feed=`"
169170

python/paddle/fluid/tests/unittests/test_parallel_executor.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -200,14 +200,29 @@ class TestParallelExecutorBase(unittest.TestCase):
200200
def check_network_convergence(self,
201201
method,
202202
memory_opt=True,
203-
iter=10,
203+
iter=50,
204204
batch_size=None,
205205
allow_op_delay=False,
206-
feed_dict=None):
206+
feed_dict=None,
207+
seed=None,
208+
use_parallel_executor=True):
209+
def run_executor(exe, feed, fetch_list, program=None):
210+
if isinstance(exe, fluid.ParallelExecutor):
211+
res = exe.run(fetch_list=fetch_list, feed=feed)
212+
elif isinstance(exe, fluid.Executor):
213+
if program is None:
214+
program = fluid.default_main_program()
215+
res = exe.run(program=program, feed=feed, fetch_list=fetch_list)
216+
else:
217+
raise ValueError('Unknown type exe')
218+
return res
219+
207220
main = fluid.Program()
208221
startup = fluid.Program()
209222
startup.random_seed = 1 # Fix random seed
210223
with fluid.program_guard(main, startup):
224+
if seed is not None:
225+
startup.random_seed = seed
211226
loss = method(use_feed=feed_dict is not None)
212227
adam = fluid.optimizer.Adam()
213228
adam.minimize(loss)
@@ -217,18 +232,24 @@ def check_network_convergence(self,
217232
startup_exe = fluid.Executor(place)
218233
startup_exe.run(startup)
219234

220-
exe = fluid.ParallelExecutor(
221-
True, loss_name=loss.name, allow_op_delay=allow_op_delay)
235+
if use_parallel_executor:
236+
exe = fluid.ParallelExecutor(
237+
True, loss_name=loss.name, allow_op_delay=allow_op_delay)
238+
else:
239+
exe = fluid.Executor(place=place)
240+
222241
if batch_size is not None:
223242
batch_size *= fluid.core.get_cuda_device_count()
224243
begin = time.time()
225-
first_loss, = exe.run([loss.name], feed=feed_dict)
244+
first_loss, = run_executor(
245+
exe=exe, feed=feed_dict, fetch_list=[loss.name])
226246
first_loss = numpy.array(first_loss)
227247

228248
for i in xrange(iter):
229-
exe.run([], feed=feed_dict)
249+
run_executor(exe=exe, feed=feed_dict, fetch_list=[])
230250

231-
last_loss, = exe.run([loss.name], feed=feed_dict)
251+
last_loss, = run_executor(
252+
exe=exe, feed=feed_dict, fetch_list=[loss.name])
232253
end = time.time()
233254

234255
if batch_size is not None:
@@ -239,6 +260,7 @@ def check_network_convergence(self,
239260

240261
print first_loss, last_loss
241262
# self.assertGreater(first_loss[0], last_loss[0])
263+
return first_loss, last_loss
242264

243265

244266
class TestMNIST(TestParallelExecutorBase):
@@ -268,6 +290,27 @@ def test_simple_fc(self):
268290
simple_fc_net, feed_dict={"image": img,
269291
"label": label})
270292

293+
def test_simple_fc_parallel_accuracy(self):
294+
img = numpy.zeros(shape=[32, 784], dtype='float32')
295+
label = numpy.ones(shape=[32, 1], dtype='int64')
296+
single_first_loss, single_last_loss = self.check_network_convergence(
297+
method=simple_fc_net,
298+
seed=1000,
299+
feed_dict={"image": img,
300+
"label": label},
301+
use_parallel_executor=False)
302+
parallel_first_loss, parallel_last_loss = self.check_network_convergence(
303+
method=simple_fc_net,
304+
seed=1000,
305+
feed_dict={"image": img,
306+
"label": label},
307+
use_parallel_executor=True)
308+
309+
for p_f in parallel_first_loss:
310+
self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6)
311+
for p_l in parallel_last_loss:
312+
self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
313+
271314
def test_batchnorm_fc(self):
272315
self.check_network_convergence(fc_with_batchnorm)
273316
img = numpy.zeros(shape=[32, 784], dtype='float32')
@@ -496,10 +539,10 @@ def test_parallel_testing(self):
496539
share_vars_from=train_exe)
497540

498541
for i in xrange(5):
499-
test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
542+
test_loss, = test_exe.run([loss.name], feed=feed_dict)
500543
test_loss = numpy.array(test_loss)
501544

502-
train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
545+
train_loss, = train_exe.run([loss.name], feed=feed_dict)
503546
train_loss = numpy.array(train_loss)
504547
self.assertTrue(
505548
numpy.allclose(

0 commit comments

Comments (0)