Skip to content

Commit 5fc8326

Browse files
committed
Add parallel accuracy test
1 parent 494c262 commit 5fc8326

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

python/paddle/fluid/tests/unittests/test_parallel_executor.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,14 @@ def check_network_convergence(self,
203203
iter=10,
204204
batch_size=None,
205205
allow_op_delay=False,
206-
feed_dict={}):
206+
feed_dict={},
207+
random_seed=None,
208+
use_parallel_executor=True):
207209
main = fluid.Program()
208210
startup = fluid.Program()
209211
with fluid.program_guard(main, startup):
212+
if seed is not None:
213+
startup.random_seed(random_seed)
210214
loss = method(use_feed=len(feed_dict) > 0)
211215
adam = fluid.optimizer.Adam()
212216
adam.minimize(loss)
@@ -217,7 +221,11 @@ def check_network_convergence(self,
217221
startup_exe = fluid.Executor(place)
218222
startup_exe.run(startup)
219223

220-
exe = fluid.ParallelExecutor(True, loss_name=loss.name)
224+
if use_parallel_executor:
225+
exe = fluid.ParallelExecutor(True, loss_name=loss.name)
226+
else:
227+
exe = fluid.Executor(place=place)
228+
221229
if batch_size is not None:
222230
batch_size *= fluid.core.get_cuda_device_count()
223231
begin = time.time()
@@ -238,6 +246,7 @@ def check_network_convergence(self,
238246

239247
print first_loss, last_loss
240248
# self.assertGreater(first_loss[0], last_loss[0])
249+
return first_loss, last_loss
241250

242251

243252
class TestMNIST(TestParallelExecutorBase):
@@ -267,6 +276,17 @@ def test_simple_fc(self):
267276
simple_fc_net, feed_dict={"image": img,
268277
"label": label})
269278

279+
def test_simple_fc_parallel_accuracy(self):
280+
single_first_loss, single_last_loss = self.check_network_convergence(
281+
simple_fc_net, random_seed=0, use_parallel_executor=False)
282+
parallel_first_loss, parallel_last_loss = self.check_network_convergence(
283+
simple_fc_net, random_seed=0, use_parallel_executor=True)
284+
print("FUCK")
285+
print('single_first_loss=', single_first_loss)
286+
print('single_last_loss=', single_last_loss)
287+
print('parallel_first_loss=', parallel_first_loss)
288+
print('parallel_last_loss=', parallel_last_loss)
289+
270290
def test_batchnorm_fc(self):
271291
self.check_network_convergence(fc_with_batchnorm)
272292
img = numpy.zeros(shape=[32, 784], dtype='float32')

0 commit comments

Comments
 (0)