Skip to content

Commit 88fa9c2

Browse files
authored
Merge pull request #11267 from JiayiFeng/fix_reader_bug
Fix a multi-thread bug in readers
2 parents 9300212 + 499dbe0 commit 88fa9c2

File tree

7 files changed

+15
-10
lines changed

7 files changed

+15
-10
lines changed

paddle/fluid/framework/reader.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,15 @@ class ReaderBase {
3535

3636
class DecoratedReader : public ReaderBase {
3737
public:
38-
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) {
38+
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
39+
: ReaderBase(), reader_(reader) {
3940
PADDLE_ENFORCE_NOT_NULL(reader_);
4041
}
4142

4243
void ReInit() override { reader_->ReInit(); }
4344

4445
protected:
45-
ReaderBase* reader_;
46+
std::shared_ptr<ReaderBase> reader_;
4647
};
4748

4849
class FileReader : public ReaderBase {
@@ -64,7 +65,7 @@ class ReaderHolder {
6465
public:
6566
void Reset(ReaderBase* reader) { reader_.reset(reader); }
6667

67-
ReaderBase* Get() const { return reader_.get(); }
68+
std::shared_ptr<ReaderBase> Get() const { return reader_; }
6869

6970
void ReadNext(std::vector<LoDTensor>* out) {
7071
PADDLE_ENFORCE_NOT_NULL(reader_);
@@ -76,7 +77,7 @@ class ReaderHolder {
7677
}
7778

7879
private:
79-
std::unique_ptr<ReaderBase> reader_;
80+
std::shared_ptr<ReaderBase> reader_;
8081
};
8182

8283
} // namespace framework

paddle/fluid/operators/reader/create_batch_reader_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace reader {
2020

2121
class BatchReader : public framework::DecoratedReader {
2222
public:
23-
BatchReader(ReaderBase* reader, int batch_size)
23+
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size)
2424
: DecoratedReader(reader), batch_size_(batch_size) {
2525
buffer_.reserve(batch_size_);
2626
}

paddle/fluid/operators/reader/create_custom_reader_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ namespace reader {
2222

2323
class CustomReader : public framework::DecoratedReader {
2424
public:
25-
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
25+
CustomReader(const std::shared_ptr<ReaderBase>& reader,
26+
const framework::BlockDesc& sub_block,
2627
const std::vector<std::string>& source_var_names,
2728
const std::vector<std::string>& sink_var_names)
2829
: DecoratedReader(reader),

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ static constexpr size_t kChannelSize = 1; // kCacheSize - 2
3434
class DoubleBufferReader : public framework::DecoratedReader {
3535
public:
3636
explicit DoubleBufferReader(
37-
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
37+
const std::shared_ptr<ReaderBase>& reader,
38+
platform::Place target_place = platform::CPUPlace())
3839
: DecoratedReader(reader), place_(target_place) {
3940
cpu_tensor_cache_.resize(kCacheSize);
4041
gpu_tensor_cache_.resize(kCacheSize);

paddle/fluid/operators/reader/create_multi_pass_reader_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace reader {
2121

2222
class MultiPassReader : public framework::DecoratedReader {
2323
public:
24-
MultiPassReader(ReaderBase* reader, int pass_num)
24+
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
2525
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
2626

2727
void ReadNext(std::vector<framework::LoDTensor>* out) override {

paddle/fluid/operators/reader/create_shuffle_reader_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ namespace reader {
2323

2424
class ShuffleReader : public framework::DecoratedReader {
2525
public:
26-
ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0)
26+
ShuffleReader(const std::shared_ptr<ReaderBase>& reader, size_t buffer_size,
27+
size_t seed = 0)
2728
: DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
2829
VLOG(10) << "Create shuffle reader of " << reader_;
2930
if (seed_ == 0) {

paddle/fluid/operators/reader/create_threaded_reader_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ namespace reader {
2121

2222
class ThreadedReader : public framework::DecoratedReader {
2323
public:
24-
explicit ThreadedReader(ReaderBase* reader) : DecoratedReader(reader) {}
24+
explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
25+
: DecoratedReader(reader) {}
2526

2627
void ReadNext(std::vector<framework::LoDTensor>* out) override {
2728
std::lock_guard<std::mutex> lock(mutex_);

0 commit comments

Comments
 (0)