Skip to content

Commit 42e65a2

Browse files
authored
Merge pull request #8791 from reyoung/feature/extract_reader_ops
Extract create_reader_op to three files
2 parents 87568cf + 4690b9c commit 42e65a2

11 files changed

+553
-415
lines changed

paddle/fluid/framework/reader.cc

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -25,92 +25,5 @@ DDim ReaderBase::shape(size_t idx) const {
2525
return shapes_[idx];
2626
}
2727

28-
void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) {
29-
if (iteration_pos_ >= buffer_.size()) {
30-
// Reload buffer with new data
31-
buffer_.clear();
32-
buffer_.reserve(buffer_size_);
33-
for (int i = 0; i < buffer_size_; ++i) {
34-
if (reader_->HasNext()) {
35-
buffer_.push_back(std::vector<LoDTensor>());
36-
reader_->ReadNext(&buffer_.back());
37-
} else {
38-
break;
39-
}
40-
}
41-
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
42-
// optimize.
43-
std::random_shuffle(buffer_.begin(), buffer_.end());
44-
iteration_pos_ = 0;
45-
}
46-
out->clear();
47-
if (!buffer_.empty()) {
48-
std::swap(*out, buffer_[iteration_pos_++]);
49-
}
50-
// if buffer_ is empty, the 'out' will return as an empty vector.
51-
}
52-
53-
// Reads up to 'batch_size_' instances from the underlying reader and
// concatenates them along dimension 0 into one batch. 'out' receives one
// LoDTensor per field; it is returned empty when the underlying reader is
// already exhausted.
void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
  buffer_.clear();
  buffer_.reserve(batch_size_);
  // Collect as many instances as are available, up to batch_size_.
  for (int i = 0; i < batch_size_; ++i) {
    if (reader_->HasNext()) {
      buffer_.push_back(std::vector<LoDTensor>());
      reader_->ReadNext(&buffer_.back());
    } else {
      break;
    }
  }
  // Concat instances
  out->clear();
  if (buffer_.empty()) {
    // if buffer_ is empty, the 'out' will return as an empty vector.
    return;
  }
  int out_num = buffer_[0].size();
  out->reserve(out_num);
  for (int j = 0; j < out_num; ++j) {
    // Merge shapes and check data types: all instances of field j must share
    // the same dtype and trailing dims; dim 0 accumulates the row counts.
    std::type_index batch_type = buffer_[0][j].type();
    DDim batch_shape = buffer_[0][j].dims();
    for (size_t i = 1; i < buffer_.size(); ++i) {
      std::type_index ins_type = buffer_[i][j].type();
      DDim ins_shape = buffer_[i][j].dims();
      PADDLE_ENFORCE_EQ(batch_type, ins_type);
      PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
                        slice_ddim(ins_shape, 1, ins_shape.size()));
      PADDLE_ENFORCE_GT(ins_shape[0], 0);
      batch_shape[0] += ins_shape[0];
    }

    LoDTensor out_tensor;
    out_tensor.Resize(batch_shape);
    out_tensor.mutable_data(platform::CPUPlace(), batch_type);
    // Running row offset into out_tensor for the next instance's data.
    int64_t dst_offset = 0;

    // Merge lod and data
    LoD batch_lod;
    for (size_t i = 0; i < buffer_.size(); ++i) {
      DDim ins_shape = buffer_[i][j].dims();
      LoD ins_lod = buffer_[i][j].lod();
      if (i == 0) {
        batch_lod = ins_lod;
      } else {
        // Append each instance's LoD offsets (skipping the leading 0),
        // shifted by the current end of the corresponding batch level.
        PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size());
        for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) {
          auto& lod_level = batch_lod[level_idx];
          for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) {
            lod_level.push_back(ins_lod[level_idx][k] + lod_level.back());
          }
        }
      }
      // Copy this instance's rows into its slice of the batch tensor.
      Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
      TensorCopy(buffer_[i][j], platform::CPUPlace(), &dst);
      dst_offset += ins_shape[0];
    }
    out_tensor.set_lod(batch_lod);
    out->push_back(out_tensor);
  }
}
11528
} // namespace framework
11629
} // namespace paddle

paddle/fluid/framework/reader.h

Lines changed: 2 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -60,83 +60,8 @@ class DecoratedReader : public ReaderBase {
6060
ReaderBase* reader_;
6161
};
6262

63-
// file readers
64-
65-
template <typename T>
66-
class RandomDataGenerator : public FileReader {
67-
public:
68-
RandomDataGenerator(const std::vector<DDim>& shapes, float min, float max)
69-
: FileReader(shapes), min_(min), max_(max) {
70-
PADDLE_ENFORCE_LE(
71-
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
72-
unsigned int seed = std::random_device()();
73-
engine_.seed(seed);
74-
dist_ = std::uniform_real_distribution<float>(min_, max_);
75-
}
76-
77-
void ReadNext(std::vector<LoDTensor>* out) override {
78-
out->clear();
79-
out->reserve(shapes_.size());
80-
for (const DDim& shape : shapes_) {
81-
PADDLE_ENFORCE_GE(
82-
shape.size(), 2,
83-
"The rank of reader's output data should be 2 at least.(Now it's %d)",
84-
shape.size());
85-
LoDTensor out_tensor;
86-
out_tensor.Resize(shape);
87-
T* data = out_tensor.mutable_data<T>(platform::CPUPlace());
88-
int64_t numel = product(shape);
89-
for (int64_t i = 0; i < numel; ++i) {
90-
data[i] = dist_(engine_);
91-
}
92-
out->push_back(out_tensor);
93-
}
94-
}
95-
96-
bool HasNext() const override { return true; }
97-
98-
void ReInit() override { return; }
99-
100-
private:
101-
float min_;
102-
float max_;
103-
std::minstd_rand engine_;
104-
std::uniform_real_distribution<float> dist_;
105-
};
106-
107-
// decorated readers
108-
109-
// Decorates an underlying reader with buffered random shuffling: up to
// 'buffer_size' instances are read ahead, shuffled, then served one at a
// time through ReadNext().
class ShuffleReader : public DecoratedReader {
 public:
  ShuffleReader(ReaderBase* reader, int buffer_size)
      : DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) {
    buffer_.reserve(buffer_size);
  }

  void ReadNext(std::vector<LoDTensor>* out) override;

 private:
  int buffer_size_;  // maximum number of instances buffered at once
  std::vector<std::vector<LoDTensor>> buffer_;  // shuffled read-ahead buffer
  size_t iteration_pos_;  // index of the next buffered instance to serve
};
123-
124-
// Decorates an underlying reader by grouping up to 'batch_size' consecutive
// instances and concatenating them into a single batch in ReadNext().
class BatchReader : public DecoratedReader {
 public:
  BatchReader(ReaderBase* reader, int batch_size)
      : DecoratedReader(reader), batch_size_(batch_size) {
    buffer_.reserve(batch_size_);
  }

  void ReadNext(std::vector<LoDTensor>* out) override;

 private:
  int batch_size_;  // number of instances merged into one batch
  std::vector<std::vector<LoDTensor>> buffer_;  // instances pending concat
};
137-
138-
// The ReaderHolder is used as readers' unified wrapper,
139-
// making it easier to access different type readers in Variables.
63+
// The ReaderHolder is used as a reader's unified wrapper,
64+
// making it easier to access different types of readers in Variables.
14065
class ReaderHolder {
14166
public:
14267
void Reset(ReaderBase* reader) { reader_.reset(reader); }

paddle/fluid/operators/CMakeLists.txt

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function(op_library TARGET)
7070
endif()
7171

7272
# Define operators that don't need pybind here.
73-
foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "create_reader_op")
73+
foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op")
7474
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
7575
set(pybind_flag 1)
7676
endif()
@@ -128,8 +128,8 @@ else()
128128
set(DEPS_OPS ${DEPS_OPS} nccl_op)
129129
endif()
130130

131+
add_subdirectory(detail)
131132
if(WITH_DISTRIBUTE)
132-
add_subdirectory(detail)
133133
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
134134
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
135135
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
@@ -170,7 +170,6 @@ op_library(recurrent_op DEPS executor)
170170
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
171171
op_library(cos_sim_op DEPS cos_sim_functor)
172172
op_library(parallel_do_op DEPS executor)
173-
op_library(create_reader_op DEPS reader)
174173

175174
if (WITH_GPU)
176175
op_library(conv_op DEPS vol2col depthwise_conv)
@@ -189,7 +188,12 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
189188
foreach(src ${GENERAL_OPS})
190189
op_library(${src})
191190
endforeach()
192-
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_data_generator);\n")
191+
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
192+
193+
add_subdirectory(reader)
194+
foreach(src ${READER_LIBRARY})
195+
set(OP_LIBRARY ${src} ${OP_LIBRARY})
196+
endforeach()
193197

194198
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
195199

0 commit comments

Comments
 (0)