Skip to content

Commit a4f397f

Browse files
committed
add an unittest
1 parent 91b6d60 commit a4f397f

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

python/paddle/fluid/layers/io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
__all__ = [
2323
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
24-
'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
24+
'read_file', 'create_shuffle_reader', 'create_double_buffer_reader',
25+
'create_multi_pass_reader'
2526
]
2627

2728

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle.fluid as fluid
18+
import paddle.v2 as paddle
19+
import paddle.v2.dataset.mnist as mnist
20+
21+
22+
class TestMultipleReader(unittest.TestCase):
23+
def setUp(self):
24+
self.batch_size = 64
25+
self.pass_num = 3
26+
# Convert mnist to recordio file
27+
with fluid.program_guard(fluid.Program(), fluid.Program()):
28+
data_file = paddle.batch(mnist.train(), batch_size=self.batch_size)
29+
feeder = fluid.DataFeeder(
30+
feed_list=[
31+
fluid.layers.data(
32+
name='image', shape=[784]),
33+
fluid.layers.data(
34+
name='label', shape=[1], dtype='int64'),
35+
],
36+
place=fluid.CPUPlace())
37+
self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
38+
'./mnist.recordio', data_file, feeder)
39+
40+
def test_main(self):
41+
with fluid.program_guard(fluid.Program(), fluid.Program()):
42+
data_file = fluid.layers.open_recordio_file(
43+
filename='./mnist.recordio',
44+
shapes=[(-1, 784), (-1, 1)],
45+
lod_levels=[0, 0],
46+
dtypes=['float32', 'int64'])
47+
data_file = fluid.layers.create_multi_pass_reader(
48+
reader=data_file, pass_num=self.pass_num)
49+
img, label = fluid.layers.read_file(data_file)
50+
51+
if fluid.core.is_compiled_with_cuda():
52+
place = fluid.CUDAPlace(0)
53+
else:
54+
place = fluid.CPUPlace()
55+
56+
exe = fluid.Executor(place)
57+
exe.run(fluid.default_startup_program())
58+
59+
batch_count = 0
60+
while not data_file.eof():
61+
img_val, = exe.run(fetch_list=[img])
62+
batch_count += 1
63+
self.assertLessEqual(img_val.shape[0], self.batch_size)
64+
print(batch_count)
65+
data_file.reset()
66+
self.assertEqual(batch_count, self.num_batch * self.pass_num)

0 commit comments

Comments
 (0)