Skip to content

Commit f863866

Browse files
committed
Add a unit test
1 parent 02b7d8b commit f863866

File tree

5 files changed

+82
-9
lines changed

5 files changed

+82
-9
lines changed

paddle/fluid/operators/reader/open_files_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ void MultipleReader::EndScheduler() {
9494
available_thread_idx_->Close();
9595
buffer_->Close();
9696
waiting_file_idx_->Close();
97-
scheduler_.join();
97+
if (scheduler_.joinable()) {
98+
scheduler_.join();
99+
}
98100
delete buffer_;
99101
delete available_thread_idx_;
100102
delete waiting_file_idx_;

paddle/fluid/operators/reader/reader_op_registry.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,16 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
3838

3939
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
4040
const std::string& file_name, const std::vector<framework::DDim>& dims) {
41-
size_t separator_pos = file_name.find(kFileFormatSeparator);
41+
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
4242
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
4343
"File name illegal! A legal file name should be like: "
44-
"[file_format]:[file_name] (e.g., 'recordio:data_file').");
45-
std::string filetype = file_name.substr(0, separator_pos);
46-
std::string f_name = file_name.substr(separator_pos + 1);
44+
"[file_name].[file_format] (e.g., 'data_file.recordio').");
45+
std::string filetype = file_name.substr(separator_pos + 1);
4746

4847
auto itor = FileReaderRegistry().find(filetype);
4948
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
5049
"No file reader registered for '%s' format.", filetype);
51-
framework::ReaderBase* reader = (itor->second)(f_name, dims);
50+
framework::ReaderBase* reader = (itor->second)(file_name, dims);
5251
return std::unique_ptr<framework::ReaderBase>(reader);
5352
}
5453

paddle/fluid/operators/reader/reader_op_registry.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace paddle {
2121
namespace operators {
2222
namespace reader {
2323

24-
static constexpr char kFileFormatSeparator[] = ":";
24+
static constexpr char kFileFormatSeparator[] = ".";
2525

2626
using FileReaderCreator = std::function<framework::ReaderBase*(
2727
const std::string&, const std::vector<framework::DDim>&)>;

python/paddle/fluid/layers/io.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
__all__ = [
2323
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
24-
'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
24+
'open_files', 'read_file', 'create_shuffle_reader',
25+
'create_double_buffer_reader'
2526
]
2627

2728

@@ -307,7 +308,7 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
307308
'shape_concat': shape_concat,
308309
'lod_levels': lod_levels,
309310
'ranks': ranks,
310-
'filename': filenames,
311+
'file_names': filenames,
311312
'thread_num': thread_num
312313
})
313314

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle.fluid as fluid
18+
import paddle.v2 as paddle
19+
import paddle.v2.dataset.mnist as mnist
20+
from shutil import copyfile
21+
22+
23+
class TestMultipleReader(unittest.TestCase):
24+
def setUp(self):
25+
# Convert mnist to recordio file
26+
with fluid.program_guard(fluid.Program(), fluid.Program()):
27+
reader = paddle.batch(mnist.train(), batch_size=32)
28+
feeder = fluid.DataFeeder(
29+
feed_list=[ # order is image and label
30+
fluid.layers.data(
31+
name='image', shape=[784]),
32+
fluid.layers.data(
33+
name='label', shape=[1], dtype='int64'),
34+
],
35+
place=fluid.CPUPlace())
36+
self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
37+
'./mnist_0.recordio', reader, feeder)
38+
copyfile('./mnist_0.recordio', './mnist_1.recordio')
39+
copyfile('./mnist_0.recordio', './mnist_2.recordio')
40+
print(self.num_batch)
41+
42+
def test_multiple_reader(self, thread_num=3):
43+
file_list = [
44+
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
45+
]
46+
with fluid.program_guard(fluid.Program(), fluid.Program()):
47+
data_files = fluid.layers.open_files(
48+
filenames=file_list,
49+
thread_num=thread_num,
50+
shapes=[(-1, 784), (-1, 1)],
51+
lod_levels=[0, 0],
52+
dtypes=['float32', 'int64'])
53+
img, label = fluid.layers.read_file(data_files)
54+
55+
if fluid.core.is_compiled_with_cuda():
56+
place = fluid.CUDAPlace(0)
57+
else:
58+
place = fluid.CPUPlace()
59+
60+
exe = fluid.Executor(place)
61+
exe.run(fluid.default_startup_program())
62+
63+
batch_count = 0
64+
while not data_files.eof():
65+
img_val, = exe.run(fetch_list=[img])
66+
batch_count += 1
67+
print(batch_count)
68+
# data_files.reset()
69+
print("FUCK")
70+
71+
self.assertEqual(batch_count, self.num_batch * 3)

0 commit comments

Comments
 (0)