Skip to content

Commit 164f238

Browse files
committed
Polish code
1 parent f9974a4 commit 164f238

File tree

4 files changed

+6
-64
lines changed

4 files changed

+6
-64
lines changed

paddle/fluid/framework/reader.cc

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -18,45 +18,9 @@ namespace paddle {
1818
namespace framework {
1919
ReaderBase::~ReaderBase() {}
2020

21-
std::vector<std::unique_ptr<ReaderBase>> ReaderBase::SplitReader(
22-
const platform::PlaceList &places) {
23-
std::vector<std::unique_ptr<ReaderBase>> readers;
21+
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
2422

25-
auto mutex = std::make_shared<std::mutex>();
26-
for (size_t i = 0; i < places.size(); ++i) {
27-
readers.emplace_back(new ThreadSafeReader(this, mutex));
28-
}
29-
30-
return readers;
31-
}
32-
33-
void ThreadSafeReader::ReadNext(std::vector<LoDTensor> *out) {
34-
std::lock_guard<std::mutex> guard(*mutex_);
35-
reader_->ReadNext(out);
36-
}
37-
38-
void ThreadSafeReader::ReInit() {
39-
std::lock_guard<std::mutex> guard(*mutex_);
40-
reader_->ReInit();
41-
}
42-
43-
bool ThreadSafeReader::HasNext() const {
44-
std::lock_guard<std::mutex> guard(*mutex_);
45-
return reader_->HasNext();
46-
}
47-
48-
std::vector<std::unique_ptr<ReaderBase>> ThreadSafeReader::SplitReader(
49-
const platform::PlaceList &places) {
50-
std::vector<std::unique_ptr<ReaderBase>> readers;
51-
for (size_t i = 0; i < places.size(); ++i) {
52-
readers.emplace_back(new ThreadSafeReader(reader_, mutex_));
53-
}
54-
return readers;
55-
}
56-
57-
FileReaderBase::FileReaderBase(const std::vector<DDim> &dims) : dims_(dims) {}
58-
59-
void FileReaderBase::ReadNext(std::vector<LoDTensor> *out) {
23+
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
6024
ReadNextImpl(out);
6125
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
6226
for (size_t i = 0; i < dims_.size(); ++i) {

paddle/fluid/framework/reader.h

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,6 @@ class ReaderBase {
3333

3434
virtual bool HasNext() const = 0;
3535

36-
virtual std::vector<std::unique_ptr<ReaderBase>> SplitReader(
37-
const platform::PlaceList& places);
38-
3936
virtual ~ReaderBase();
4037
};
4138

@@ -53,27 +50,9 @@ class DecoratedReader : public ReaderBase {
5350
ReaderBase* reader_;
5451
};
5552

56-
class ThreadSafeReader : public DecoratedReader {
57-
public:
58-
ThreadSafeReader(ReaderBase* reader, const std::shared_ptr<std::mutex>& mutex)
59-
: DecoratedReader(reader), mutex_(mutex) {}
60-
61-
void ReadNext(std::vector<LoDTensor>* out) override;
62-
63-
void ReInit() override;
64-
65-
bool HasNext() const override;
66-
67-
std::vector<std::unique_ptr<ReaderBase>> SplitReader(
68-
const platform::PlaceList& places) override;
69-
70-
private:
71-
std::shared_ptr<std::mutex> mutex_;
72-
};
73-
74-
class FileReaderBase : public ReaderBase {
53+
class FileReader : public ReaderBase {
7554
public:
76-
explicit FileReaderBase(const std::vector<DDim>& dims);
55+
explicit FileReader(const std::vector<DDim>& dims);
7756

7857
void ReadNext(std::vector<LoDTensor>* out) override;
7958

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ class DoubleBufferReader : public framework::DecoratedReader {
3939
#ifdef PADDLE_WITH_CUDA
4040
ctxs_.emplace_back(new platform::CUDADeviceContext(
4141
boost::get<platform::CUDAPlace>(place_)));
42-
#else
4342
#endif
4443
}
4544
}

paddle/fluid/operators/reader/create_recordio_file_reader_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
namespace paddle {
1919
namespace operators {
2020
namespace reader {
21-
class RecordIOFileReader : public framework::FileReaderBase {
21+
class RecordIOFileReader : public framework::FileReader {
2222
public:
2323
explicit RecordIOFileReader(const std::string& filename,
2424
const std::vector<framework::DDim>& dims)
25-
: FileReaderBase(dims),
25+
: FileReader(dims),
2626
scanner_(filename),
2727
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
2828
platform::CPUPlace())) {}

0 commit comments

Comments
 (0)