@@ -200,17 +200,29 @@ class TestParallelExecutorBase(unittest.TestCase):
     def check_network_convergence(self,
                                   method,
                                   memory_opt=True,
-                                  iter=10,
+                                  iter=50,
                                   batch_size=None,
                                   allow_op_delay=False,
                                   feed_dict={},
                                   seed=None,
                                   use_parallel_executor=True):
+        def run_executor(exe, feed, fetch_list, program=None):
+            if isinstance(exe, fluid.ParallelExecutor):
+                res = exe.run(fetch_list=fetch_list, feed=feed)
+            elif isinstance(exe, fluid.Executor):
+                if program is None:
+                    program = fluid.default_main_program()
+                res = exe.run(program=program, feed=feed, fetch_list=fetch_list)
+            else:
+                raise ValueError('Unknown type exe')
+            return res
+
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
             if seed is not None:
                 startup.random_seed = seed
+                main.random_seed = seed
             loss = method(use_feed=len(feed_dict) > 0)
             adam = fluid.optimizer.Adam()
             adam.minimize(loss)
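
Note: the new run_executor helper exists because fluid.Executor.run expects an explicit Program (falling back to fluid.default_main_program() here), while fluid.ParallelExecutor.run is already bound to one and takes only feed/fetch_list. A minimal, self-contained sketch of the plain-Executor branch; the tiny network below is a hypothetical stand-in, not the test's simple_fc_net:

    import numpy
    import paddle.fluid as fluid

    # Hypothetical one-layer classifier, just to have something to run.
    image = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    prediction = fluid.layers.fc(input=image, size=10, act='softmax')
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(input=prediction, label=label))

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())  # initialize parameters first

    img = numpy.zeros(shape=[32, 784], dtype='float32')
    lbl = numpy.ones(shape=[32, 1], dtype='int64')
    # This is the call shape run_executor produces for a plain Executor:
    loss_val, = exe.run(program=fluid.default_main_program(),
                        feed={'image': img, 'label': lbl},
                        fetch_list=[loss.name])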
@@ -229,13 +241,15 @@ def check_network_convergence(self,
         if batch_size is not None:
             batch_size *= fluid.core.get_cuda_device_count()
         begin = time.time()
-        first_loss, = exe.run([loss.name], feed=feed_dict)
+        first_loss, = run_executor(
+            exe=exe, feed=feed_dict, fetch_list=[loss.name])
         first_loss = numpy.array(first_loss)
 
         for i in xrange(iter):
-            exe.run([], feed=feed_dict)
+            run_executor(exe=exe, feed=feed_dict, fetch_list=[])
 
-        last_loss, = exe.run([loss.name], feed=feed_dict)
+        last_loss, = run_executor(
+            exe=exe, feed=feed_dict, fetch_list=[loss.name])
         end = time.time()
 
         if batch_size is not None:
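
Note: batch_size is multiplied by fluid.core.get_cuda_device_count() because each device here runs a step on the full fed batch, so throughput reporting has to count all devices. Illustrative arithmetic with assumed numbers, not taken from the test:

    batch_size = 32
    cuda_device_count = 4            # assumed; the test queries fluid.core
    batch_size *= cuda_device_count  # effective examples per step: 128

    iter, elapsed = 50, 12.8         # 50 timed steps in 12.8 seconds (assumed)
    print('%.4f instances per second' % (batch_size * iter / elapsed))
    # prints: 500.0000 instances per second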
@@ -277,14 +291,25 @@ def test_simple_fc(self):
                                       "label": label})
 
     def test_simple_fc_parallel_accuracy(self):
-        #single_first_loss, single_last_loss = self.check_network_convergence(
-        #    simple_fc_net, seed=0, use_parallel_executor=False)
-        #parallel_first_loss, parallel_last_loss = self.check_network_convergence(
-        #    simple_fc_net, seed=0, use_parallel_executor=True)
-        print('single_first_loss=', single_first_loss)
-        print('single_last_loss=', single_last_loss)
-        print('parallel_first_loss=', parallel_first_loss)
-        print('parallel_last_loss=', parallel_last_loss)
+        img = numpy.zeros(shape=[32, 784], dtype='float32')
+        label = numpy.ones(shape=[32, 1], dtype='int64')
+        single_first_loss, single_last_loss = self.check_network_convergence(
+            method=simple_fc_net,
+            seed=1000,
+            feed_dict={"image": img,
+                       "label": label},
+            use_parallel_executor=False)
+        parallel_first_loss, parallel_last_loss = self.check_network_convergence(
+            method=simple_fc_net,
+            seed=1000,
+            feed_dict={"image": img,
+                       "label": label},
+            use_parallel_executor=True)
+
+        for p_f in parallel_first_loss:
+            self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6)
+        for p_l in parallel_last_loss:
+            self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
 
     def test_batchnorm_fc(self):
         self.check_network_convergence(fc_with_batchnorm)
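
Note on the new assertions: a fetched loss from a ParallelExecutor run comes back with one entry per device, while the single-Executor run yields a one-element array, hence the loops comparing each per-device value against single_*_loss[0]. A shape sketch with assumed values:

    import numpy

    single_first_loss = numpy.array([2.3026], dtype='float32')  # one value
    parallel_first_loss = numpy.array(
        [2.3026, 2.3026], dtype='float32')  # assumed two devices

    for p_f in parallel_first_loss:
        assert abs(p_f - single_first_loss[0]) <= 1e-6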