Skip to content

Commit 2532b92

Browse files
committed
Add more unittests and fix bugs
1 parent f863866 commit 2532b92

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

paddle/fluid/operators/reader/open_files_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ void MultipleReader::ScheduleThreadFunc() {
122122
// No more file to read.
123123
++completed_thread_num;
124124
if (completed_thread_num == prefetchers_.size()) {
125+
buffer_->Close();
125126
break;
126127
}
127128
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
mnist.recordio
2+
mnist_0.recordio
3+
mnist_1.recordio
4+
mnist_2.recordio

python/paddle/fluid/tests/unittests/test_multiple_reader.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222

2323
class TestMultipleReader(unittest.TestCase):
2424
def setUp(self):
25+
self.batch_size = 64
2526
# Convert mnist to recordio file
2627
with fluid.program_guard(fluid.Program(), fluid.Program()):
27-
reader = paddle.batch(mnist.train(), batch_size=32)
28+
reader = paddle.batch(mnist.train(), batch_size=self.batch_size)
2829
feeder = fluid.DataFeeder(
2930
feed_list=[ # order is image and label
3031
fluid.layers.data(
@@ -37,9 +38,8 @@ def setUp(self):
3738
'./mnist_0.recordio', reader, feeder)
3839
copyfile('./mnist_0.recordio', './mnist_1.recordio')
3940
copyfile('./mnist_0.recordio', './mnist_2.recordio')
40-
print(self.num_batch)
4141

42-
def test_multiple_reader(self, thread_num=3):
42+
def main(self, thread_num):
4343
file_list = [
4444
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
4545
]
@@ -64,8 +64,11 @@ def test_multiple_reader(self, thread_num=3):
6464
while not data_files.eof():
6565
img_val, = exe.run(fetch_list=[img])
6666
batch_count += 1
67-
print(batch_count)
68-
# data_files.reset()
69-
print("FUCK")
70-
67+
self.assertLessEqual(img_val.shape[0], self.batch_size)
68+
data_files.reset()
7169
self.assertEqual(batch_count, self.num_batch * 3)
70+
71+
def test_main(self):
72+
self.main(thread_num=3) # thread number equals to file number
73+
self.main(thread_num=10) # thread number is larger than file number
74+
self.main(thread_num=2) # thread number is less than file number

0 commit comments

Comments
 (0)