@@ -38,18 +38,23 @@ def parse_args():
38
38
default = 'resnet' ,
39
39
help = 'The model to run benchmark with.' )
40
40
parser .add_argument (
41
- '--batch_size' , type = int , default = 32 , help = 'The minibatch size.' )
41
+ '--batch_size' ,
42
+ type = int ,
43
+ default = 32 ,
44
+ help = 'The batch size on each gpu.' )
42
45
parser .add_argument (
43
46
'--learning_rate' , type = float , default = 0.001 , help = 'The learning rate.' )
44
- # TODO(wuyi): add "--use_fake_data" option back.
45
47
parser .add_argument (
46
48
'--skip_batch_num' ,
47
49
type = int ,
48
50
default = 5 ,
49
51
help = 'The first num of minibatch num to skip, for better performance test'
50
52
)
51
53
parser .add_argument (
52
- '--iterations' , type = int , default = 80 , help = 'The number of minibatches.' )
54
+ '--iterations' ,
55
+ type = int ,
56
+ default = 80 ,
57
+ help = 'The number of minibatches, set to -1 to run all batches.' )
53
58
parser .add_argument (
54
59
'--pass_num' , type = int , default = 100 , help = 'The number of passes.' )
55
60
parser .add_argument (
@@ -69,6 +74,7 @@ def parse_args():
69
74
type = int ,
70
75
default = 1 ,
71
76
help = 'If gpus > 1, will use ParallelExecutor to run, else use Executor.' )
77
+ # this option is available only for vgg and resnet.
72
78
parser .add_argument (
73
79
'--cpus' ,
74
80
type = int ,
@@ -78,7 +84,7 @@ def parse_args():
78
84
'--data_set' ,
79
85
type = str ,
80
86
default = 'flowers' ,
81
- choices = ['cifar10' , 'flowers' ],
87
+ choices = ['cifar10' , 'flowers' , 'imagenet' ],
82
88
help = 'Optional dataset for benchmark.' )
83
89
parser .add_argument (
84
90
'--infer_only' , action = 'store_true' , help = 'If set, run forward only.' )
@@ -108,6 +114,16 @@ def parse_args():
108
114
default = 'local' ,
109
115
choices = ['local' , 'pserver' , 'nccl2' ],
110
116
help = 'Choose parameter update method, can be local, pserver, nccl2.' )
117
+ parser .add_argument (
118
+ '--use_reader_op' ,
119
+ action = 'store_true' ,
120
+ help = 'Whether to use reader op, and must specify the data path if set this to true.'
121
+ )
122
+ parser .add_argument (
123
+ '--data_path' ,
124
+ type = str ,
125
+ default = "" ,
126
+ help = 'Directory that contains all the training recordio files.' )
111
127
args = parser .parse_args ()
112
128
return args
113
129
@@ -210,26 +226,50 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
210
226
place = core .CPUPlace () if args .device == 'CPU' else core .CUDAPlace (0 )
211
227
exe = fluid .Executor (place )
212
228
exe .run (startup_prog )
213
- feed_var_list = [
214
- var for var in train_prog .global_block ().vars .itervalues ()
215
- if var .is_data
216
- ]
217
- feeder = fluid .DataFeeder (feed_var_list , place )
229
+
230
+ if not args .use_reader_op :
231
+ feed_var_list = [
232
+ var for var in train_prog .global_block ().vars .itervalues ()
233
+ if var .is_data
234
+ ]
235
+ feeder = fluid .DataFeeder (feed_var_list , place )
218
236
219
237
iters , num_samples , start_time = 0 , 0 , time .time ()
220
238
for pass_id in range (args .pass_num ):
221
239
train_losses = []
222
- for batch_id , data in enumerate (train_reader ()):
240
+ if not args .use_reader_op :
241
+ reader_generator = train_reader ()
242
+ batch_id = 0
243
+ data = None
244
+ while True :
245
+ if not args .use_reader_op :
246
+ data = next (reader_generator , None )
247
+ if data == None :
248
+ break
249
+ if iters == args .iterations :
250
+ break
223
251
if iters == args .skip_batch_num :
224
252
start_time = time .time ()
225
253
num_samples = 0
226
- if iters == args .iterations :
227
- break
228
- loss = exe .run (train_prog ,
229
- feed = feeder .feed (data ),
230
- fetch_list = [avg_loss ])
254
+
255
+ if args .use_reader_op :
256
+ try :
257
+ loss = exe .run (train_prog , fetch_list = [avg_loss ])
258
+ except fluid .core .EnforceNotMet as ex :
259
+ break
260
+ else :
261
+ loss = exe .run (train_prog ,
262
+ feed = feeder .feed (data ),
263
+ fetch_list = [avg_loss ])
231
264
iters += 1
232
- num_samples += len (data )
265
+ batch_id += 1
266
+ # FIXME(wuyi): For use_reader_op, if the current
267
+ # pass is not the last, the last batch of this pass
268
+ # is also equal to args.batch_size.
269
+ if args .use_reader_op :
270
+ num_samples += args .batch_size * args .gpus
271
+ else :
272
+ num_samples += len (data )
233
273
train_losses .append (loss )
234
274
print ("Pass: %d, Iter: %d, Loss: %f\n " %
235
275
(pass_id , iters , np .mean (train_losses )))
@@ -250,10 +290,14 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
250
290
def train_parallel (avg_loss , infer_prog , optimizer , train_reader , test_reader ,
251
291
batch_acc , args , train_prog , startup_prog , nccl_id_var ,
252
292
num_trainers , trainer_id ):
253
- feed_var_list = [
254
- var for var in train_prog .global_block ().vars .itervalues ()
255
- if var .is_data
256
- ]
293
+ place = core .CPUPlace () if args .device == 'CPU' else core .CUDAPlace (0 )
294
+ if not args .use_reader_op :
295
+ feed_var_list = [
296
+ var for var in train_prog .global_block ().vars .itervalues ()
297
+ if var .is_data
298
+ ]
299
+ feeder = fluid .DataFeeder (feed_var_list , place )
300
+
257
301
# Generate fake input data:
258
302
if args .use_fake_data :
259
303
for var in feed_var_list :
@@ -270,7 +314,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
270
314
"value" : 1.0 ,
271
315
"dtype" : var .dtype })
272
316
273
- place = core .CPUPlace () if args .device == 'CPU' else core .CUDAPlace (0 )
274
317
if nccl_id_var and trainer_id == 0 :
275
318
# FIXME(wuyi): wait for the other trainers to start listening
276
319
time .sleep (30 )
@@ -287,12 +330,21 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
287
330
num_trainers = num_trainers ,
288
331
trainer_id = trainer_id )
289
332
290
- feeder = fluid .DataFeeder (feed_var_list , place )
291
333
for pass_id in range (args .pass_num ):
292
334
num_samples = 0
293
335
iters = 0
294
336
start_time = time .time ()
295
- for batch_id , data in enumerate (train_reader ()):
337
+ if not args .use_reader_op :
338
+ reader_generator = train_reader ()
339
+ batch_id = 0
340
+ data = None
341
+ while True :
342
+ if not args .use_reader_op :
343
+ data = next (reader_generator , None )
344
+ if data == None :
345
+ break
346
+ if iters == args .iterations :
347
+ break
296
348
if args .profile and pass_id == 0 and batch_id == 5 :
297
349
profiler .start_profiler ("All" )
298
350
elif args .profile and pass_id == 0 and batch_id == 10 :
@@ -301,19 +353,26 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
301
353
if iters == args .skip_batch_num :
302
354
start_time = time .time ()
303
355
num_samples = 0
304
- if iters == args .iterations :
305
- break
306
- if args .use_fake_data :
307
- loss , = exe .run ([avg_loss .name ])
356
+ if args .use_fake_data or args .use_reader_op :
357
+ try :
358
+ loss , = exe .run ([avg_loss .name ])
359
+ except fluid .core .EnforceNotMet as ex :
360
+ break
308
361
else :
309
362
loss , = exe .run ([avg_loss .name ], feed = feeder .feed (data ))
310
363
if args .update_method == "pserver" :
311
364
exe .bcast_params ()
312
- num_samples += len (data )
365
+ if args .use_reader_op :
366
+ num_samples += args .batch_size * args .gpus
367
+ else :
368
+ num_samples += len (data )
313
369
iters += 1
314
370
if batch_id % 1 == 0 :
315
371
print ("Pass %d, batch %d, loss %s" %
316
372
(pass_id , batch_id , np .array (loss )))
373
+ batch_id += 1
374
+ if args .use_reader_op :
375
+ num_samples = num_samples * args .gpus
317
376
print_train_time (start_time , time .time (), num_samples )
318
377
if not args .no_test and batch_acc :
319
378
test_acc = test (startup_exe , infer_prog , test_reader , feeder ,
0 commit comments