Skip to content

Commit 3d677b1

Browse files
committed
fix compile errors and make OpenFilesOpMaker derived from FileReaderMakerBase
1 parent 5506225 commit 3d677b1

File tree

4 files changed

+25
-33
lines changed

4 files changed

+25
-33
lines changed

paddle/fluid/operators/reader/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ function(reader_library TARGET_NAME)
1515
PARENT_SCOPE)
1616
endfunction()
1717

18+
reader_library(open_files_op SRCS open_files_op.cc)
1819
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
1920
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
2021
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
2122
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
2223
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
23-
reader_library(open_files_op SRCS open_files_op.cc)
2424
# Export local libraries to parent
2525
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)

paddle/fluid/operators/reader/open_files_op.cc

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -161,31 +161,20 @@ class OpenFilesOp : public framework::OperatorBase {
161161
}
162162
};
163163

164-
class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker {
164+
class OpenFilesOpMaker : public FileReaderMakerBase {
165165
public:
166166
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
167-
: OpProtoAndCheckerMaker(op_proto, op_checker) {
167+
: FileReaderMakerBase(op_proto, op_checker) {
168+
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
169+
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
170+
.GreaterThan(0);
171+
168172
AddComment(R"DOC(
169173
OpenFiles Operator
170174
171175
An OpenFilesOp creates a MultipleReader, which is able to
172176
read data multi-threaded from multiple files.
173177
)DOC");
174-
AddOutput("Out", "(ReaderHolder) The created MultipleReader.");
175-
AddAttr<std::vector<int>>("shape_concat",
176-
"The concat of all data's shapes.");
177-
AddAttr<std::vector<int>>(
178-
"ranks",
179-
"The ranks of each data."
180-
"e.g."
181-
"shape_concat = [2,3,4,5,6]"
182-
"ranks = [3,2]"
183-
"It means the reader will generate two data each time,"
184-
"whose shapes are [2,3,4] and [5,6] respectively.");
185-
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
186-
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
187-
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
188-
.GreaterThan(0);
189178
}
190179
};
191180

@@ -196,4 +185,4 @@ class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker {
196185
namespace reader = paddle::operators::reader;
197186

198187
REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
199-
reader::OpenFilesOpMaker);
188+
reader::OpenFilesOpMaker);

paddle/fluid/operators/reader/reader_op_registry.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,22 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
3636
return regs;
3737
}
3838

39+
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
40+
const std::string& file_name, const std::vector<framework::DDim>& dims) {
41+
size_t separator_pos = file_name.find(kFileFormatSeparator);
42+
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
43+
"File name illegal! A legal file name should be like: "
44+
"[file_format]:[file_name] (e.g., 'recordio:data_file').");
45+
std::string filetype = file_name.substr(0, separator_pos);
46+
std::string f_name = file_name.substr(separator_pos + 1);
47+
48+
auto itor = FileReaderRegistry().find(filetype);
49+
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
50+
"No file reader registered for '%s' format.", filetype);
51+
framework::ReaderBase* reader = (itor->second)(f_name, dims);
52+
return std::unique_ptr<framework::ReaderBase>(reader);
53+
}
54+
3955
FileReaderMakerBase::FileReaderMakerBase(
4056
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
4157
framework::OpAttrChecker* op_checker)

paddle/fluid/operators/reader/reader_op_registry.h

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,7 @@ int RegisterFileReader(const std::string& filetype) {
3838
}
3939

4040
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
41-
const std::string& file_name, const std::vector<framework::DDim>& dims) {
42-
size_t separator_pos = file_name.find(kFileFormatSeparator);
43-
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
44-
"File name illegal! A legal file name should be like: "
45-
"[file_format]:[file_name] (e.g., 'recordio:data_file').");
46-
std::string filetype = file_name.substr(0, separator_pos);
47-
std::string f_name = file_name.substr(separator_pos + 1);
48-
49-
auto itor = FileReaderRegistry().find(filetype);
50-
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
51-
"No file reader registered for '%s' format.", filetype);
52-
framework::ReaderBase* reader = (itor->second)(f_name, dims);
53-
return std::unique_ptr<framework::ReaderBase>(reader);
54-
}
41+
const std::string& file_name, const std::vector<framework::DDim>& dims);
5542

5643
extern std::vector<framework::DDim> RestoreShapes(
5744
const std::vector<int>& shape_concat, const std::vector<int>& ranks);

0 commit comments

Comments
 (0)