Skip to content

Commit 72b7815

Browse files
committed
Polish reader speed
1 parent e576345 commit 72b7815

File tree

4 files changed

+28
-20
lines changed

4 files changed

+28
-20
lines changed

paddle/fluid/framework/lod_tensor.cc

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -312,19 +312,22 @@ void WriteToRecordIO(recordio::Writer *writer,
312312
writer->Write(buffer.str());
313313
}
314314

315-
std::vector<LoDTensor> ReadFromRecordIO(
316-
recordio::Scanner *scanner, const platform::DeviceContext &dev_ctx) {
317-
std::vector<LoDTensor> result;
318-
if (scanner->HasNext()) {
319-
std::istringstream sin(scanner->Next());
320-
uint32_t sz;
321-
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
322-
result.resize(sz);
323-
for (uint32_t i = 0; i < sz; ++i) {
324-
DeserializeFromStream(sin, &result[i], dev_ctx);
325-
}
315+
bool ReadFromRecordIO(recordio::Scanner *scanner,
316+
const platform::DeviceContext &dev_ctx,
317+
std::vector<LoDTensor> *result_ptr) {
318+
if (!scanner->HasNext()) {
319+
return false;
326320
}
327-
return result;
321+
std::istringstream sin(scanner->Next());
322+
uint32_t sz;
323+
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
324+
auto &result = *result_ptr;
325+
result.resize(sz);
326+
for (uint32_t i = 0; i < sz; ++i) {
327+
DeserializeFromStream(sin, &result[i], dev_ctx);
328+
}
329+
330+
return true;
328331
}
329332

330333
std::vector<LoDTensor> LoDTensor::SplitLoDTensor(

paddle/fluid/framework/lod_tensor.h

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -223,8 +223,9 @@ extern void WriteToRecordIO(recordio::Writer* writer,
223223
const std::vector<LoDTensor>& tensor,
224224
const platform::DeviceContext& dev_ctx);
225225

226-
extern std::vector<LoDTensor> ReadFromRecordIO(
227-
recordio::Scanner* scanner, const platform::DeviceContext& dev_ctx);
226+
extern bool ReadFromRecordIO(recordio::Scanner* scanner,
227+
const platform::DeviceContext& dev_ctx,
228+
std::vector<LoDTensor>* result_ptr);
228229

229230
/*
230231
* Convert between length-based LoD and offset-based LoD.

paddle/fluid/framework/lod_tensor_test.cc

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -301,11 +301,12 @@ static void TestRecordIO() {
301301
{
302302
std::unique_ptr<std::istream> stream_ptr(stream);
303303
recordio::Scanner scanner(std::move(stream_ptr));
304-
auto tensors = ReadFromRecordIO(&scanner, ctx);
304+
std::vector<framework::LoDTensor> tensors;
305+
ASSERT_TRUE(ReadFromRecordIO(&scanner, ctx, &tensors));
305306
ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
306307
assert_tensor_ok(tensors[0]);
307308
assert_tensor_ok(tensors[1]);
308-
tensors = ReadFromRecordIO(&scanner, ctx);
309+
ASSERT_TRUE(ReadFromRecordIO(&scanner, ctx, &tensors));
309310
ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
310311
assert_tensor_ok(tensors[0]);
311312
assert_tensor_ok(tensors[1]);

paddle/fluid/operators/reader/create_recordio_file_reader_op.cc

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -33,11 +33,14 @@ class RecordIOFileReader : public framework::FileReader {
3333

3434
protected:
3535
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
36+
std::unique_ptr<std::lock_guard<std::mutex>> guard;
3637
if (ThreadSafe) {
37-
std::lock_guard<std::mutex> guard(*mutex_);
38-
*out = framework::ReadFromRecordIO(&scanner_, dev_ctx_);
39-
} else {
40-
*out = framework::ReadFromRecordIO(&scanner_, dev_ctx_);
38+
guard.reset(new std::lock_guard<std::mutex>(*mutex_));
39+
}
40+
41+
bool ok = framework::ReadFromRecordIO(&scanner_, dev_ctx_, out);
42+
if (!ok) {
43+
out->clear();
4144
}
4245
}
4346

0 commit comments

Comments
 (0)