Skip to content

Commit 1478a5f

Browse files
committed
Make open_files use buffer
1 parent dc34eff commit 1478a5f

File tree

4 files changed

+30
-19
lines changed

4 files changed

+30
-19
lines changed

paddle/fluid/operators/reader/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ function(reader_library TARGET_NAME)
1616
endfunction()
1717

1818
cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool)
19-
reader_library(open_files_op SRCS open_files_op.cc)
19+
reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader)
2020
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
2121
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
2222
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)

paddle/fluid/operators/reader/open_files_op.cc

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
#include "ThreadPool.h"
1919
#include "paddle/fluid/framework/blocking_queue.h"
2020
#include "paddle/fluid/operators/reader/blocking_queue.h"
21+
#include "paddle/fluid/operators/reader/buffered_reader.h"
2122
#include "paddle/fluid/operators/reader/reader_op_registry.h"
2223

2324
namespace paddle {
@@ -232,12 +233,17 @@ class OpenFilesOp : public framework::OperatorBase {
232233
container.reset(new OrderedReaderContainer());
233234
} else {
234235
container.reset(new PreemptiveReaderContainer(
235-
std::min(file_names.size(),
236-
static_cast<size_t>(std::thread::hardware_concurrency()))));
236+
static_cast<size_t>(Attr<int>("thread_num"))));
237237
}
238238

239-
out->Reset(
240-
std::make_shared<MultiFileReader>(file_names, std::move(container)));
239+
auto reader =
240+
std::make_shared<MultiFileReader>(file_names, std::move(container));
241+
auto buffer_size = Attr<int>("buffer_size");
242+
if (buffer_size > 1) {
243+
reader = framework::MakeDecoratedReader<BufferedReader>(
244+
reader, platform::CPUPlace(), buffer_size);
245+
}
246+
out->Reset(reader);
241247
}
242248
};
243249

@@ -253,6 +259,8 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
253259
An OpenFilesOp creates a MultiFileReader, which is able to
254260
read data multi-threaded from multiple files.
255261
)DOC");
262+
AddAttr<int>("thread_num", "Number of thread to read files.");
263+
AddAttr<int>("buffer_size", "The reading buffer of these files.");
256264
}
257265
};
258266

python/paddle/fluid/layers/io.py

Lines changed: 16 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
from ..executor import global_scope
2222
from layer_function_generator import generate_layer_fn, templatedoc
2323
import sys
24+
import multiprocessing
2425

2526
__all__ = [
2627
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
@@ -549,10 +550,9 @@ def open_files(filenames,
549550
shapes(list): List of tuples which declaring data shapes.
550551
lod_levels(list): List of ints which declaring data lod_level.
551552
dtypes(list): List of strs which declaring data type.
552-
thread_num(None): Deprecated argument. It will be set by open_files
553-
automatically.
554-
buffer_size(None): Deprecated argument. It will be set by open_files
555-
automatically.
553+
thread_num(None): The number of thread to read files.
554+
Default: min(len(filenames), cpu_number).
555+
buffer_size(None): The buffer size of reader. Default: 3 * thread_num
556556
pass_num(int): Number of passes to run.
557557
is_test(bool|None): Whether `open_files` used for testing or not. If it
558558
is used for testing, the order of data generated is same as the file
@@ -574,14 +574,15 @@ def open_files(filenames,
574574
# Via the reader, we can use 'read_file' layer to get data:
575575
image, label = fluid.layers.io.read_file(reader)
576576
"""
577-
if thread_num is not None:
578-
print >> sys.stderr, "thread_num parameter of open_files is " \
579-
"deprecated. It will be ignored and set " \
580-
"automatically by open_files "
581-
if buffer_size is not None:
582-
print >> sys.stderr, "buffer_size parameter of open_files is " \
583-
"deprecated. It will be ignored and set " \
584-
"automatically by open_files "
577+
if thread_num is None:
578+
thread_num = min(len(filenames), multiprocessing.cpu_count())
579+
else:
580+
thread_num = int(thread_num)
581+
582+
if buffer_size is None:
583+
buffer_size = 3 * thread_num
584+
else:
585+
buffer_size = int(buffer_size)
585586

586587
if isinstance(filenames, basestring):
587588
filenames = [filenames]
@@ -600,7 +601,9 @@ def open_files(filenames,
600601
'shape_concat': shape_concat,
601602
'lod_levels': lod_levels,
602603
'ranks': ranks,
603-
'file_names': filenames
604+
'file_names': filenames,
605+
'thread_num': thread_num,
606+
'buffer_size': buffer_size
604607
}
605608
if is_test is not None:
606609
attrs['is_test'] = is_test

python/paddle/fluid/tests/unittests/test_data_balance.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -155,7 +155,7 @@ def main_lod(self):
155155
main_program=main_prog,
156156
build_strategy=build_strategy)
157157

158-
if (parallel_exe.device_count > self.batch_size):
158+
if parallel_exe.device_count > self.batch_size:
159159
print("WARNING: Unittest TestDataBalance skipped. \
160160
For the result is not correct when device count \
161161
is larger than batch size.")

0 commit comments

Comments
 (0)