Skip to content

Commit 7eedced

Browse files
committed
Polish RecordIO
1 parent cfca8a3 commit 7eedced

File tree

8 files changed

+47
-14
lines changed

8 files changed

+47
-14
lines changed

paddle/fluid/framework/reader.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class ReaderBase {
3333
std::vector<DDim> shapes() const { return shapes_; }
3434
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
3535

36+
virtual bool HasNext() const = 0;
37+
3638
virtual ~ReaderBase() {}
3739

3840
protected:
@@ -53,6 +55,8 @@ class DecoratedReader : public ReaderBase {
5355

5456
void ReInit() override { reader_->ReInit(); }
5557

58+
bool HasNext() const override { return reader_->HasNext(); }
59+
5660
protected:
5761
ReaderBase* reader_;
5862
};
@@ -74,6 +78,8 @@ class ReaderHolder {
7478
reader_->set_shapes(shapes);
7579
}
7680

81+
bool HasNext() const { return reader_->HasNext(); }
82+
7783
private:
7884
std::unique_ptr<ReaderBase> reader_;
7985
};

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class DoubleBufferReader : public framework::DecoratedReader {
3737

3838
~DoubleBufferReader() { buffer_->Close(); }
3939

40+
bool HasNext() const override;
41+
4042
private:
4143
void PrefetchThreadFunc();
4244

@@ -106,6 +108,8 @@ void DoubleBufferReader::PrefetchThreadFunc() {
106108
}
107109
}
108110

111+
bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); }
112+
109113
} // namespace reader
110114
} // namespace operators
111115
} // namespace paddle

paddle/fluid/operators/reader/create_random_data_generator_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ class RandomDataGenerator : public framework::FileReader {
5252

5353
void ReInit() override { return; }
5454

55+
bool HasNext() const override { return true; }
56+
5557
private:
5658
float min_;
5759
float max_;

paddle/fluid/recordio/chunk.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ bool Chunk::Parse(std::istream& sin) {
146146
std::string buf;
147147
buf.resize(rec_len);
148148
stream.read(&buf[0], rec_len);
149+
PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
149150
Add(buf);
150151
}
151152
return true;

paddle/fluid/recordio/scanner.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ void Scanner::Reset() {
3232
ParseNextChunk();
3333
}
3434

35-
const std::string &Scanner::Next() {
35+
std::string Scanner::Next() {
3636
PADDLE_ENFORCE(!eof_, "StopIteration");
37-
auto &rec = cur_chunk_.Record(offset_++);
37+
auto rec = cur_chunk_.Record(offset_++);
3838
if (offset_ == cur_chunk_.NumRecords()) {
3939
ParseNextChunk();
4040
}

paddle/fluid/recordio/scanner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class Scanner {
2828

2929
void Reset();
3030

31-
const std::string& Next();
31+
std::string Next();
3232

3333
bool HasNext() const;
3434

paddle/fluid/recordio/writer_scanner_test.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,29 @@ TEST(WriterScanner, Normal) {
4141
ASSERT_EQ("CDE", scanner.Next());
4242
ASSERT_FALSE(scanner.HasNext());
4343
}
44+
}
45+
46+
TEST(WriterScanner, TinyChunk) {
47+
std::stringstream* stream = new std::stringstream();
48+
{
49+
paddle::recordio::Writer writer(
50+
stream, paddle::recordio::Compressor::kNoCompress, 2 /*max chunk num*/);
51+
writer.Write("ABC");
52+
writer.Write("BCD");
53+
writer.Write("CDE");
54+
writer.Write("DEFG");
55+
writer.Flush();
56+
}
57+
58+
{
59+
stream->seekg(0, std::ios::beg);
60+
std::unique_ptr<std::istream> stream_ptr(stream);
61+
paddle::recordio::Scanner scanner(std::move(stream_ptr));
62+
ASSERT_TRUE(scanner.HasNext());
63+
ASSERT_EQ(scanner.Next(), "ABC");
64+
ASSERT_EQ(scanner.Next(), "BCD");
65+
ASSERT_EQ(scanner.Next(), "CDE");
66+
ASSERT_EQ(scanner.Next(), "DEFG");
67+
ASSERT_FALSE(scanner.HasNext());
68+
}
4469
}

python/paddle/fluid/tests/unittests/test_recordio_reader.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,10 @@ def test_main(self):
5555
exe.run(fluid.default_startup_program())
5656
avg_loss_np = []
5757

58-
for i in xrange(2): # 2 pass
59-
batch_id = 0
60-
while not data_file.eof():
61-
try:
62-
batch_id += 1
63-
tmp, = exe.run(fetch_list=[avg_loss])
64-
avg_loss_np.append(tmp)
65-
except:
66-
print batch_id
67-
break
68-
data_file.reset()
58+
# train a pass
59+
while not data_file.eof():
60+
tmp, = exe.run(fetch_list=[avg_loss])
61+
avg_loss_np.append(tmp)
62+
data_file.reset()
63+
6964
self.assertLess(avg_loss_np[-1], avg_loss_np[0])

0 commit comments

Comments
 (0)