
Commit 8d14b39

Author: yi.wu
Commit message: follow comments
Parent: 725ea3f

6 files changed: 30 additions and 13 deletions


benchmark/fluid/fluid_benchmark.py

Lines changed: 25 additions & 8 deletions
@@ -38,7 +38,10 @@ def parse_args():
         default='resnet',
         help='The model to run benchmark with.')
     parser.add_argument(
-        '--batch_size', type=int, default=32, help='The minibatch size.')
+        '--batch_size',
+        type=int,
+        default=32,
+        help='The batch size on each gpu.')
     parser.add_argument(
         '--learning_rate', type=float, default=0.001, help='The learning rate.')
     parser.add_argument(
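
Note on the hunk above (not part of the diff): the flag is now documented as a per-GPU batch size, and the model readers later in this commit are scaled by args.gpus, so the number of samples consumed per training iteration is batch_size * gpus. A minimal sketch of that arithmetic, using assumed example values:

    # Illustration only; the numbers are assumptions, the relationship follows this commit.
    per_gpu_batch_size = 32                              # --batch_size
    num_gpus = 4                                         # --gpus
    global_batch_size = per_gpu_batch_size * num_gpus    # 128 samples per training iteration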
@@ -229,27 +232,35 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
     iters, num_samples, start_time = 0, 0, time.time()
     for pass_id in range(args.pass_num):
         train_losses = []
-        reader_generator = train_reader()
+        if not args.use_reader_op:
+            reader_generator = train_reader()
         batch_id = 0
         data = None
         while True:
             if not args.use_reader_op:
                 data = next(reader_generator, None)
-            if iters == args.iterations or data == None:
+                if data == None:
+                    break
+            if iters == args.iterations:
                 break
             if iters == args.skip_batch_num:
                 start_time = time.time()
                 num_samples = 0

             if args.use_reader_op:
-                loss = exe.run(train_prog, fetch_list=[avg_loss])
+                try:
+                    loss = exe.run(train_prog, fetch_list=[avg_loss])
+                except fluid.core.EnforceNotMet as ex:
+                    break
             else:
                 loss = exe.run(train_prog,
                                feed=feeder.feed(data),
                                fetch_list=[avg_loss])
             iters += 1
             batch_id += 1
-            # FIXME(wuyi): last batch size maybe different
+            # FIXME(wuyi): For use_reader_op, if the current
+            # pass is not the last, the last batch of this pass
+            # is also equal to args.batch_size.
             num_samples += len(args.batch_size)
             train_losses.append(loss)
             print("Pass: %d, Iter: %d, Loss: %f\n" %
@@ -315,13 +326,16 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
         num_samples = 0
         iters = 0
         start_time = time.time()
-        reader_generator = train_reader()
+        if not args.use_reader_op:
+            reader_generator = train_reader()
         batch_id = 0
         data = None
         while True:
             if not args.use_reader_op:
                 data = next(reader_generator, None)
-            if iters == args.iterations or data == None:
+                if data == None:
+                    break
+            if iters == args.iterations:
                 break
             if args.profile and pass_id == 0 and batch_id == 5:
                 profiler.start_profiler("All")
@@ -335,7 +349,10 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
             if args.use_reader_op and iters >= args.iterations / args.gpus:
                 break
             if args.use_fake_data or args.use_reader_op:
-                loss, = exe.run([avg_loss.name])
+                try:
+                    loss, = exe.run([avg_loss.name])
+                except fluid.core.EnforceNotMet as ex:
+                    break
             else:
                 loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
             if args.update_method == "pserver":
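
In the multi-device path each exe.run call consumes one global batch (batch_size * gpus samples), which is why the loop above caps itself at iterations / gpus executor steps under --use_reader_op; the EnforceNotMet guard ends the pass early if the in-graph reader runs out first. A rough sketch of the step budget, with assumed example values:

    # Illustration only; values are assumptions, the formula mirrors the diff.
    iterations = 100                          # --iterations, counted in per-GPU batches
    gpus = 4                                  # --gpus
    max_parallel_steps = iterations / gpus    # 25 executor steps, each covering 4 per-GPU batches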

benchmark/fluid/models/machine_translation.py

Lines changed: 1 addition & 1 deletion
@@ -223,7 +223,7 @@ def get_model(args):
     train_batch_generator = paddle.batch(
         paddle.reader.shuffle(
             paddle.dataset.wmt14.train(dict_size), buf_size=1000),
-        batch_size=args.batch_size)
+        batch_size=args.batch_size * args.gpus)

     test_batch_generator = paddle.batch(
         paddle.reader.shuffle(
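
This file and the four model files below apply the same one-line change: the training reader is batched at args.batch_size * args.gpus so a single reader batch can be split evenly across devices, while the test readers keep the unscaled size. A generic sketch of the pattern, with a placeholder raw reader and a hypothetical helper name:

    # Sketch only: generic form of the reader change repeated in the model files below.
    import paddle

    def build_train_reader(raw_reader, per_gpu_batch_size, num_gpus, buf_size=1000):
        # Shuffle, then batch at the global size so each device receives per_gpu_batch_size samples.
        return paddle.batch(
            paddle.reader.shuffle(raw_reader, buf_size=buf_size),
            batch_size=per_gpu_batch_size * num_gpus)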

benchmark/fluid/models/mnist.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def get_model(args):

     # Reader
     train_reader = paddle.batch(
-        paddle.dataset.mnist.train(), batch_size=args.batch_size)
+        paddle.dataset.mnist.train(), batch_size=args.batch_size * args.gpus)
     test_reader = paddle.batch(
         paddle.dataset.mnist.test(), batch_size=args.batch_size)
     return avg_cost, inference_program, opt, train_reader, test_reader, batch_acc

benchmark/fluid/models/resnet.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ def get_model(args):
     batched_train_reader = paddle.batch(
         paddle.reader.shuffle(
             train_reader, buf_size=5120),
-        batch_size=args.batch_size)
+        batch_size=args.batch_size * args.gpus)
     batched_test_reader = paddle.batch(train_reader, batch_size=args.batch_size)

     return avg_cost, inference_program, optimizer, batched_train_reader, batched_test_reader, batch_acc

benchmark/fluid/models/stacked_dynamic_lstm.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def gate_common(
     train_reader = batch(
         paddle.reader.shuffle(
             crop_sentence(imdb.train(word_dict), crop_size), buf_size=25000),
-        batch_size=args.batch_size)
+        batch_size=args.batch_size * args.gpus)
     test_reader = batch(
         paddle.reader.shuffle(
             crop_sentence(imdb.test(word_dict), crop_size), buf_size=25000),

benchmark/fluid/models/vgg.py

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ def get_model(args):
             paddle.dataset.cifar.train10()
             if args.data_set == 'cifar10' else paddle.dataset.flowers.train(),
             buf_size=5120),
-        batch_size=args.batch_size)
+        batch_size=args.batch_size * args.gpus)
     test_reader = paddle.batch(
         paddle.dataset.cifar.test10()
         if args.data_set == 'cifar10' else paddle.dataset.flowers.test(),
