Commit a28a462

Author: yi.wu
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fluid_benchmark_support_recordioreader
2 parents: 47630a4 + 2381249

File tree: 90 files changed, +1821 −846 lines

Note: large commits hide some content by default; only a subset of the 90 changed files is shown below.


benchmark/.gitignore

Lines changed: 3 additions & 0 deletions
@@ -7,3 +7,6 @@ paddle/rnn/imdb.pkl
 caffe/image/logs
 tensorflow/image/logs
 tensorflow/rnn/logs
+fluid/models/*.pyc
+fluid/logs
+fluid/nohup.out

benchmark/fluid/fluid_benchmark.py

Lines changed: 38 additions & 22 deletions
@@ -40,18 +40,18 @@ def parse_args():
     parser.add_argument(
         '--batch_size', type=int, default=32, help='The minibatch size.')
     parser.add_argument(
-        '--learning_rate',
-        type=float,
-        default=0.001,
-        help='The minibatch size.')
+        '--learning_rate', type=float, default=0.001, help='The learning rate.')
     parser.add_argument(
         '--skip_batch_num',
         type=int,
         default=5,
         help='The first num of minibatch num to skip, for better performance test'
     )
     parser.add_argument(
-        '--iterations', type=int, default=80, help='The number of minibatches.')
+        '--iterations',
+        type=int,
+        default=80,
+        help='The number of minibatches, set to -1 to run all batches.')
     parser.add_argument(
         '--pass_num', type=int, default=100, help='The number of passes.')
     parser.add_argument(
@@ -71,11 +71,12 @@ def parse_args():
         type=int,
         default=1,
         help='If gpus > 1, will use ParallelExecutor to run, else use Executor.')
+    # this option is available only for vgg and resnet.
     parser.add_argument(
         '--data_set',
         type=str,
         default='flowers',
-        choices=['cifar10', 'flowers'],
+        choices=['cifar10', 'flowers', 'imagenet'],
         help='Optional dataset for benchmark.')
     parser.add_argument(
         '--infer_only', action='store_true', help='If set, run forward only.')
@@ -228,27 +229,32 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
     iters, num_samples, start_time = 0, 0, time.time()
     for pass_id in range(args.pass_num):
         train_losses = []
-        for batch_id, data in enumerate(train_reader()):
+        reader_generator = train_reader()
+        batch_id = 0
+        data = None
+        while True:
+            if not args.use_reader_op:
+                data = next(reader_generator, None)
+                if iters == args.iterations or data == None:
+                    break
             if iters == args.skip_batch_num:
                 start_time = time.time()
                 num_samples = 0
-            if iters == args.iterations:
-                break
+
             if args.use_reader_op:
                 loss = exe.run(train_prog, fetch_list=[avg_loss])
             else:
                 loss = exe.run(train_prog,
                                feed=feeder.feed(data),
                                fetch_list=[avg_loss])
             iters += 1
-            num_samples += len(data)
+            batch_id += 1
+            # FIXME(wuyi): last batch size maybe different
+            num_samples += len(args.batch_size)
             train_losses.append(loss)
             print("Pass: %d, Iter: %d, Loss: %f\n" %
                   (pass_id, iters, np.mean(train_losses)))
-        train_elapsed = time.time() - start_time
-        examples_per_sec = num_samples / train_elapsed
-        print('\nTotal examples: %d, total time: %.5f, %.5f examples/sec\n' %
-              (num_samples, train_elapsed, examples_per_sec))
+        print_train_time(start_time, time.time(), num_samples)
         print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses)))
         # evaluation
         if not args.no_test and batch_acc != None:
@@ -309,7 +315,14 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
         num_samples = 0
         iters = 0
         start_time = time.time()
-        for batch_id, data in enumerate(train_reader()):
+        reader_generator = train_reader()
+        batch_id = 0
+        data = None
+        while True:
+            if not args.use_reader_op:
+                data = next(reader_generator, None)
+                if iters == args.iterations or data == None:
+                    break
             if args.profile and pass_id == 0 and batch_id == 5:
                 profiler.start_profiler("All")
             elif args.profile and pass_id == 0 and batch_id == 10:
@@ -318,8 +331,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
             if iters == args.skip_batch_num:
                 start_time = time.time()
                 num_samples = 0
-            if iters == args.iterations:
-                break
             # NOTE: if use reader ops, the input data is not splited to multiple cards
             if args.use_reader_op and iters >= args.iterations / args.gpus:
                 break
@@ -334,12 +345,10 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
             if batch_id % 1 == 0:
                 print("Pass %d, batch %d, loss %s" %
                       (pass_id, batch_id, np.array(loss)))
-        train_elapsed = time.time() - start_time
+            batch_id += 1
         if args.use_reader_op:
             num_samples = num_samples * args.gpus
-        examples_per_sec = num_samples / train_elapsed
-        print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
-              (num_samples, train_elapsed, examples_per_sec))
+        print_train_time(start_time, time.time(), num_samples)
         if not args.no_test and batch_acc != None:
             test_acc = test(startup_exe, infer_prog, test_reader, feeder,
                             batch_acc)
@@ -350,12 +359,19 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
 def print_arguments(args):
     vars(args)['use_nvprof'] = (vars(args)['use_nvprof'] and
                                 vars(args)['device'] == 'GPU')
-    print('----------- resnet Configuration Arguments -----------')
+    print('----------- Configuration Arguments -----------')
     for arg, value in sorted(vars(args).iteritems()):
         print('%s: %s' % (arg, value))
     print('------------------------------------------------')


+def print_train_time(start_time, end_time, num_samples):
+    train_elapsed = end_time - start_time
+    examples_per_sec = num_samples / train_elapsed
+    print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' %
+          (num_samples, train_elapsed, examples_per_sec))
+
+
 def main():
     args = parse_args()
     print_arguments(args)
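
The core change in both `train` and `train_parallel` is replacing `for batch_id, data in enumerate(train_reader())` with a hand-driven `while` loop, so the same loop body works whether batches are fed from Python (`feeder.feed(data)`) or pulled inside the graph by reader ops (no Python-side `data` at all). A minimal, Paddle-free sketch of the pattern (names here are illustrative, not from the commit):

```python
def make_reader():
    """Stand-in for train_reader(): a zero-arg callable returning a generator."""
    def reader():
        for i in range(10):
            yield [i]  # one batch
    return reader

use_reader_op = False  # when True, the framework pulls data itself; `data` stays None
iterations = 5         # with -1, `iters == iterations` never fires, so all batches run

train_reader = make_reader()
reader_generator = train_reader()
iters, batch_id, data = 0, 0, None
while True:
    if not use_reader_op:
        # next() with a default returns None at exhaustion instead of raising StopIteration
        data = next(reader_generator, None)
        if iters == iterations or data is None:
            break
    # ... run one training step here, feeding `data` when it is not None ...
    iters += 1
    batch_id += 1
print("ran %d iterations" % iters)
```

As an aside, the replacement accounting line `num_samples += len(args.batch_size)` would raise a TypeError, since `batch_size` is an int; `num_samples += args.batch_size` looks like the intent (the adjacent FIXME already notes the last batch may be smaller).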

benchmark/fluid/models/resnet.py

Lines changed: 23 additions & 10 deletions
@@ -27,6 +27,7 @@
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle.fluid.profiler as profiler
+from recordio_converter import imagenet_train, imagenet_test


 def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
@@ -123,13 +124,30 @@ def get_model(args):
     else:
         dshape = [32, 32, 3]
         model = resnet_cifar10
-    else:
+        train_reader = paddle.dataset.cifar.train10()
+        test_reader = paddle.dataset.cifar.test10()
+    elif args.data_set == "flowers":
         class_dim = 102
         if args.data_format == 'NCHW':
             dshape = [3, 224, 224]
         else:
             dshape = [224, 224, 3]
         model = resnet_imagenet
+        train_reader = paddle.dataset.flowers.train()
+        test_reader = paddle.dataset.flowers.test()
+    elif args.data_set == "imagenet":
+        class_dim = 1000
+        if args.data_format == 'NCHW':
+            dshape = [3, 224, 224]
+        else:
+            dshape = [224, 224, 3]
+        model = resnet_imagenet
+        if not args.data_dir:
+            raise Exception(
+                "Must specify --data_dir when training with imagenet")
+        train_reader = imagenet_train(args.data_dir)
+        test_reader = imagenet_test(args.data_dir)
+
     if args.use_reader_op:
         filelist = [
             os.path.join(args.data_path, f) for f in os.listdir(args.data_path)
@@ -163,15 +181,10 @@ def get_model(args):

     optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)

-    train_reader = paddle.batch(
+    batched_train_reader = paddle.batch(
         paddle.reader.shuffle(
-            paddle.dataset.cifar.train10()
-            if args.data_set == 'cifar10' else paddle.dataset.flowers.train(),
-            buf_size=5120),
-        batch_size=args.batch_size)
-    test_reader = paddle.batch(
-        paddle.dataset.cifar.test10()
-        if args.data_set == 'cifar10' else paddle.dataset.flowers.test(),
+            train_reader, buf_size=5120),
         batch_size=args.batch_size)
+    batched_test_reader = paddle.batch(train_reader, batch_size=args.batch_size)

-    return avg_cost, inference_program, optimizer, train_reader, test_reader, batch_acc
+    return avg_cost, inference_program, optimizer, batched_train_reader, batched_test_reader, batch_acc
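
`get_model` now picks an unbatched reader per dataset and batches it in one place. The readers follow the legacy Paddle convention: a reader is a zero-argument callable returning a sample generator, and decorators such as `paddle.reader.shuffle` and `paddle.batch` wrap one reader into another. A rough, Paddle-free sketch of what those two decorators do (assumed behavior, not Paddle's actual source):

```python
import random

def shuffle(reader, buf_size):
    """Wrap `reader`: buffer up to buf_size samples, shuffle, then emit them."""
    def shuffled_reader():
        buf = []
        for sample in reader():
            buf.append(sample)
            if len(buf) >= buf_size:
                random.shuffle(buf)
                for s in buf:
                    yield s
                buf = []
        random.shuffle(buf)  # flush the (possibly partial) tail
        for s in buf:
            yield s
    return shuffled_reader

def batch(reader, batch_size):
    """Wrap `reader`: group consecutive samples into lists of batch_size."""
    def batched_reader():
        b = []
        for sample in reader():
            b.append(sample)
            if len(b) == batch_size:
                yield b
                b = []
        if b:
            yield b  # last batch may be smaller
    return batched_reader

# usage, mirroring the diff:
# batched_train_reader = batch(shuffle(train_reader, buf_size=5120), batch_size=32)
```

(Note that the diff passes `train_reader` to the test-side `paddle.batch` call as well; `test_reader` looks like the intended argument.)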

benchmark/fluid/recordio_converter.py

Lines changed: 37 additions & 6 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 import os
+import random
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
@@ -53,6 +54,13 @@ def prepare_flowers(outpath, batch_size):
         [1])


+def default_mapper(sample):
+    img, label = sample
+    img = image.simple_transform(
+        img, 256, 224, True, mean=[103.94, 116.78, 123.68])
+    return img.flatten().astype('float32'), label
+
+
 def imagenet_train(data_dir):
     contents = os.listdir(data_dir)
     if set(contents) != set(
@@ -68,6 +76,8 @@ def imagenet_train(data_dir):
             img, lbl = l[:-1].split(" ")
             img2label[img] = int(lbl)
             imgfilelist.append(img)
+    # shuffle all, this is slow
+    random.shuffle(imgfilelist)

     def train_reader():
         for idx, imgfile in enumerate(imgfilelist):
@@ -76,15 +86,36 @@ def train_reader():
             label = [img2label[imgfile], ]
             yield [data, label]

-    def default_mapper(sample):
-        img, label = sample
-        img = image.simple_transform(
-            img, 256, 224, True, mean=[103.94, 116.78, 123.68])
-        return img.flatten().astype('float32'), label
-
     return paddle.reader.map_readers(default_mapper, train_reader)


+def imagenet_test(data_dir):
+    contents = os.listdir(data_dir)
+    if set(contents) != set(
+            ["train", "train.txt", "val", "val_set", "val.txt", "unzip.sh"]):
+        raise Exception("Imagenet data contents error!")
+    img2label = dict()
+    imgfilelist = []
+    with open(os.path.join(data_dir, "val.txt")) as fn:
+        while 1:
+            l = fn.readline()
+            if not l:
+                break
+            img, lbl = l[:-1].split(" ")
+            img2label[img] = int(lbl)
+            imgfilelist.append(img)
+
+    def test_reader():
+        for idx, imgfile in enumerate(imgfilelist):
+            base_path = os.path.join(data_dir, "val", imgfile.split(".")[0])
+            image_path = ".".join([base_path, "jpeg"])
+            data = image.load_image(image_path)
+            label = [img2label[imgfile], ]
+            yield [data, label]
+
+    return paddle.reader.map_readers(default_mapper, test_reader)
+
+
 # FIXME(wuyi): delete this when https://github.com/PaddlePaddle/Paddle/pull/11066 is merged
 def convert_reader_to_recordio_files(
     filename,
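
Both `imagenet_train` and `imagenet_test` hand their raw file-reading generator to `paddle.reader.map_readers`, which applies `default_mapper` (decode, resize/crop, mean-subtract, flatten) to every sample. For the single-reader case used here, it behaves conceptually like this sketch (an illustration of the idea, not Paddle's implementation):

```python
def map_readers(func, reader):
    """Return a reader whose samples are func(sample) for each sample of `reader`."""
    def mapped_reader():
        for sample in reader():
            yield func(sample)
    return mapped_reader

# usage, mirroring the diff:
# reader = map_readers(default_mapper, train_reader)
# for img_vec, label in reader():
#     ...
```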

benchmark/fluid/run.sh

Lines changed: 11 additions & 9 deletions
@@ -2,6 +2,7 @@
 # This script benchmarking the PaddlePaddle Fluid on
 # single thread single GPU.

+mkdir -p logs
 #export FLAGS_fraction_of_gpu_memory_to_use=0.0
 export CUDNN_PATH=/paddle/cudnn_v5

@@ -35,6 +36,7 @@ nohup stdbuf -oL nvidia-smi \
     --format=csv \
     --filename=mem.log \
     -l 1 &
+
 # mnist
 # mnist gpu mnist 128
 FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
@@ -43,7 +45,7 @@ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
     --batch_size=128 \
     --skip_batch_num=5 \
     --iterations=500 \
-    2>&1 | tee -a mnist_gpu_128.log
+    2>&1 | tee -a logs/mnist_gpu_128.log

 # vgg16
 # gpu cifar10 128
@@ -53,7 +55,7 @@ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
     --batch_size=128 \
     --skip_batch_num=5 \
     --iterations=30 \
-    2>&1 | tee -a vgg16_gpu_128.log
+    2>&1 | tee -a logs/vgg16_gpu_128.log

 # flowers gpu 128
 FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
@@ -63,28 +65,28 @@ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
     --data_set=flowers \
     --skip_batch_num=5 \
     --iterations=30 \
-    2>&1 | tee -a vgg16_gpu_flowers_32.log
+    2>&1 | tee -a logs/vgg16_gpu_flowers_32.log

 # resnet50
 # resnet50 gpu cifar10 128
 FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
-    --model=resnet50 \
+    --model=resnet \
     --device=GPU \
     --batch_size=128 \
     --data_set=cifar10 \
     --skip_batch_num=5 \
     --iterations=30 \
-    2>&1 | tee -a resnet50_gpu_128.log
+    2>&1 | tee -a logs/resnet50_gpu_128.log

 # resnet50 gpu flowers 64
 FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
-    --model=resnet50 \
+    --model=resnet \
     --device=GPU \
     --batch_size=64 \
     --data_set=flowers \
     --skip_batch_num=5 \
     --iterations=30 \
-    2>&1 | tee -a resnet50_gpu_flowers_64.log
+    2>&1 | tee -a logs/resnet50_gpu_flowers_64.log

 # lstm
 # lstm gpu imdb 32 # tensorflow only support batch=32
@@ -94,7 +96,7 @@ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
     --batch_size=32 \
     --skip_batch_num=5 \
     --iterations=30 \
-    2>&1 | tee -a lstm_gpu_32.log
+    2>&1 | tee -a logs/lstm_gpu_32.log

 # seq2seq
 # seq2seq gpu wmb 128
@@ -104,4 +106,4 @@ FLAGS_benchmark=true stdbuf -oL python fluid_benchmark.py \
     --batch_size=128 \
     --skip_batch_num=5 \
     --iterations=30 \
-    2>&1 | tee -a lstm_gpu_128.log
+    2>&1 | tee -a logs/lstm_gpu_128.log

cmake/external/grpc.cmake

Lines changed: 9 additions & 2 deletions
@@ -33,10 +33,18 @@ ELSE()
     SET(BUILD_CMD make HAS_SYSTEM_PROTOBUF=false -s -j ${NUM_OF_PROCESSOR} static grpc_cpp_plugin)
 ENDIF()

+# FIXME(wuyi): do not build zlib cares protobuf twice, find a way to build grpc with them
 ExternalProject_Add(
     extern_grpc
     DEPENDS protobuf zlib
-    URL "http://paddlepaddledeps.bj.bcebos.com/grpc.tar.xz"
+    # NOTE(wuyi):
+    # this package is generated by following steps:
+    # 1. git clone -b v1.8.x https://github.com/grpc/grpc.git
+    # 2. submodule update --init
+    # 3. keep only zlib, cares, protobuf, boringssl under "third_party",
+    #    checkout and clean other dirs under third_party
+    # 4. remove .git, and package the directory.
+    URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.8.x.tar.gz"
     PREFIX ${GRPC_SOURCES_DIR}
     UPDATE_COMMAND ""
     CONFIGURE_COMMAND ""
@@ -49,7 +57,6 @@ ExternalProject_Add(
     INSTALL_COMMAND make prefix=${GRPC_INSTALL_DIR} install
 )

-# FIXME(typhoonzero): hack to get static lib path, try a better way like merge them.
 ADD_LIBRARY(grpc++_unsecure STATIC IMPORTED GLOBAL)
 SET_PROPERTY(TARGET grpc++_unsecure PROPERTY IMPORTED_LOCATION
     "${GRPC_INSTALL_DIR}/lib/libgrpc++_unsecure.a")
