@@ -203,10 +203,14 @@ def check_network_convergence(self,
203
203
iter = 10 ,
204
204
batch_size = None ,
205
205
allow_op_delay = False ,
206
- feed_dict = {}):
206
+ feed_dict = {},
207
+ random_seed = None ,
208
+ use_parallel_executor = True ):
207
209
main = fluid .Program ()
208
210
startup = fluid .Program ()
209
211
with fluid .program_guard (main , startup ):
212
+ if seed is not None :
213
+ startup .random_seed (random_seed )
210
214
loss = method (use_feed = len (feed_dict ) > 0 )
211
215
adam = fluid .optimizer .Adam ()
212
216
adam .minimize (loss )
@@ -217,7 +221,11 @@ def check_network_convergence(self,
217
221
startup_exe = fluid .Executor (place )
218
222
startup_exe .run (startup )
219
223
220
- exe = fluid .ParallelExecutor (True , loss_name = loss .name )
224
+ if use_parallel_executor :
225
+ exe = fluid .ParallelExecutor (True , loss_name = loss .name )
226
+ else :
227
+ exe = fluid .Executor (place = place )
228
+
221
229
if batch_size is not None :
222
230
batch_size *= fluid .core .get_cuda_device_count ()
223
231
begin = time .time ()
@@ -238,6 +246,7 @@ def check_network_convergence(self,
238
246
239
247
print first_loss , last_loss
240
248
# self.assertGreater(first_loss[0], last_loss[0])
249
+ return first_loss , last_loss
241
250
242
251
243
252
class TestMNIST (TestParallelExecutorBase ):
@@ -267,6 +276,17 @@ def test_simple_fc(self):
267
276
simple_fc_net , feed_dict = {"image" : img ,
268
277
"label" : label })
269
278
279
+ def test_simple_fc_parallel_accuracy (self ):
280
+ single_first_loss , single_last_loss = self .check_network_convergence (
281
+ simple_fc_net , random_seed = 0 , use_parallel_executor = False )
282
+ parallel_first_loss , parallel_last_loss = self .check_network_convergence (
283
+ simple_fc_net , random_seed = 0 , use_parallel_executor = True )
284
+ print ("FUCK" )
285
+ print ('single_first_loss=' , single_first_loss )
286
+ print ('single_last_loss=' , single_last_loss )
287
+ print ('parallel_first_loss=' , parallel_first_loss )
288
+ print ('parallel_last_loss=' , parallel_last_loss )
289
+
270
290
def test_batchnorm_fc (self ):
271
291
self .check_network_convergence (fc_with_batchnorm )
272
292
img = numpy .zeros (shape = [32 , 784 ], dtype = 'float32' )
0 commit comments