Skip to content

Commit e13aec6

Browse files
authored
Merge pull request #8830 from reyoung/feature/recordio_file_reader
Feature/recordio file reader
2 parents 1f757f5 + 7eedced commit e13aec6

32 files changed

+826
-45
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ endif()
2121
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
2222

2323
nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place paddle_memory device_context init)
24-
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
24+
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio)
2525
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
2626
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init)
2727

paddle/fluid/framework/lod_tensor.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ limitations under the License. */
1919
#include "paddle/fluid/memory/memcpy.h"
2020
#include "paddle/fluid/memory/memory.h"
2121

22+
#include "paddle/fluid/recordio/scanner.h"
23+
#include "paddle/fluid/recordio/writer.h"
24+
2225
#include <stdint.h>
2326
#include <string.h>
2427
#include <algorithm>
@@ -291,6 +294,31 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
291294
TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
292295
}
293296

297+
void WriteToRecordIO(recordio::Writer &writer,
298+
const std::vector<LoDTensor> &tensor,
299+
const platform::DeviceContext &dev_ctx) {
300+
std::stringstream buffer;
301+
size_t sz = tensor.size();
302+
buffer.write(reinterpret_cast<const char *>(&sz), sizeof(uint32_t));
303+
for (auto &each : tensor) {
304+
SerializeToStream(buffer, each, dev_ctx);
305+
}
306+
writer.Write(buffer.str());
307+
}
308+
309+
std::vector<LoDTensor> ReadFromRecordIO(
310+
recordio::Scanner &scanner, const platform::DeviceContext &dev_ctx) {
311+
std::istringstream sin(scanner.Next());
312+
uint32_t sz;
313+
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
314+
std::vector<LoDTensor> result;
315+
result.resize(sz);
316+
for (uint32_t i = 0; i < sz; ++i) {
317+
DeserializeFromStream(sin, &result[i], dev_ctx);
318+
}
319+
return result;
320+
}
321+
294322
std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
295323
const std::vector<platform::Place> places) const {
296324
check_memory_size();

paddle/fluid/framework/lod_tensor.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ limitations under the License. */
2929
#include "paddle/fluid/platform/place.h"
3030

3131
namespace paddle {
32+
33+
namespace recordio {
34+
class Writer;
35+
class Scanner;
36+
}
37+
3238
namespace framework {
3339

3440
/*
@@ -209,5 +215,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
209215
void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
210216
const platform::DeviceContext& dev_ctx);
211217

218+
extern void WriteToRecordIO(recordio::Writer& writer,
219+
const std::vector<LoDTensor>& tensor,
220+
const platform::DeviceContext& dev_ctx);
221+
222+
extern std::vector<LoDTensor> ReadFromRecordIO(
223+
recordio::Scanner& scanner, const platform::DeviceContext& dev_ctx);
224+
212225
} // namespace framework
213226
} // namespace paddle

paddle/fluid/framework/lod_tensor_test.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
#include "paddle/fluid/framework/lod_tensor.h"
1616

17+
#include "paddle/fluid/recordio/scanner.h"
18+
#include "paddle/fluid/recordio/writer.h"
19+
1720
#include <glog/logging.h>
1821
#include <gtest/gtest.h>
1922
#include <algorithm>
@@ -224,5 +227,43 @@ TEST(LoD, CheckAbsLoD) {
224227
abs_lod0.push_back(std::vector<size_t>({0}));
225228
ASSERT_FALSE(CheckAbsLoD(abs_lod0));
226229
}
230+
231+
TEST(LoDTensor, RecordIO) {
232+
LoDTensor tensor;
233+
int* tmp = tensor.mutable_data<int>(make_ddim({4, 5}), platform::CPUPlace());
234+
for (int i = 0; i < 20; ++i) {
235+
tmp[i] = i;
236+
}
237+
238+
std::stringstream* stream = new std::stringstream();
239+
auto& ctx =
240+
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
241+
{
242+
recordio::Writer writer(stream, recordio::Compressor::kSnappy);
243+
WriteToRecordIO(writer, {tensor, tensor}, ctx);
244+
WriteToRecordIO(writer, {tensor, tensor}, ctx);
245+
writer.Flush();
246+
}
247+
248+
auto assert_tensor_ok = [](const LoDTensor& tensor) {
249+
for (int i = 0; i < 20; ++i) {
250+
ASSERT_EQ(tensor.data<int>()[i], i);
251+
}
252+
};
253+
254+
{
255+
std::unique_ptr<std::istream> stream_ptr(stream);
256+
recordio::Scanner scanner(std::move(stream_ptr));
257+
auto tensors = ReadFromRecordIO(scanner, ctx);
258+
ASSERT_EQ(tensors.size(), 2);
259+
assert_tensor_ok(tensors[0]);
260+
assert_tensor_ok(tensors[1]);
261+
tensors = ReadFromRecordIO(scanner, ctx);
262+
ASSERT_EQ(tensors.size(), 2);
263+
assert_tensor_ok(tensors[0]);
264+
assert_tensor_ok(tensors[1]);
265+
}
266+
}
267+
227268
} // namespace framework
228269
} // namespace paddle

paddle/fluid/framework/reader.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class ReaderBase {
3333
std::vector<DDim> shapes() const { return shapes_; }
3434
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
3535

36+
virtual bool HasNext() const = 0;
37+
3638
virtual ~ReaderBase() {}
3739

3840
protected:
@@ -53,6 +55,8 @@ class DecoratedReader : public ReaderBase {
5355

5456
void ReInit() override { reader_->ReInit(); }
5557

58+
bool HasNext() const override { return reader_->HasNext(); }
59+
5660
protected:
5761
ReaderBase* reader_;
5862
};
@@ -87,6 +91,8 @@ class ReaderHolder {
8791
reader_->set_shapes(shapes);
8892
}
8993

94+
bool HasNext() const { return reader_->HasNext(); }
95+
9096
private:
9197
std::unique_ptr<ReaderBase> reader_;
9298
};

paddle/fluid/operators/detail/safe_ref.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include "paddle/fluid/platform/enforce.h"
18+
1719
namespace paddle {
1820
namespace operators {
1921
namespace detail {
Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,24 @@
11
cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader)
2-
op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry)
3-
op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry)
4-
op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry)
5-
op_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS reader_op_registry)
6-
set(READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op create_double_buffer_reader_op PARENT_SCOPE)
2+
set(LOCAL_READER_LIBS)
3+
4+
function(reader_library TARGET_NAME)
5+
set(oneValueArgs "")
6+
set(multiValueArgs SRCS DEPS)
7+
set(options "")
8+
set(common_deps reader_op_registry)
9+
cmake_parse_arguments(reader_library "${options}" "${oneValueArgs}"
10+
"${multiValueArgs}" ${ARGN})
11+
op_library(${TARGET_NAME} SRCS ${reader_library_SRCS} DEPS ${common_deps} ${reader_library_DEPS})
12+
set(LOCAL_READER_LIBS
13+
${TARGET_NAME}
14+
${LOCAL_READER_LIBS}
15+
PARENT_SCOPE)
16+
endfunction()
17+
18+
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
19+
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
20+
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
21+
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
22+
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
23+
# Export local libraries to parent
24+
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class DoubleBufferReader : public framework::DecoratedReader {
3737

3838
~DoubleBufferReader() { buffer_->Close(); }
3939

40+
bool HasNext() const override;
41+
4042
private:
4143
void PrefetchThreadFunc();
4244

@@ -106,6 +108,8 @@ void DoubleBufferReader::PrefetchThreadFunc() {
106108
}
107109
}
108110

111+
bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); }
112+
109113
} // namespace reader
110114
} // namespace operators
111115
} // namespace paddle

paddle/fluid/operators/reader/create_random_data_generator_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ class RandomDataGenerator : public framework::FileReader {
5252

5353
void ReInit() override { return; }
5454

55+
bool HasNext() const override { return true; }
56+
5557
private:
5658
float min_;
5759
float max_;
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/reader/reader_op_registry.h"
16+
#include "paddle/fluid/recordio/scanner.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
namespace reader {
21+
class RecordIOFileReader : public framework::FileReader {
22+
public:
23+
RecordIOFileReader(const std::string& filename,
24+
const std::vector<framework::DDim>& shapes)
25+
: FileReader(shapes),
26+
scanner_(filename),
27+
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
28+
platform::CPUPlace())) {}
29+
30+
void ReadNext(std::vector<framework::LoDTensor>* out) override {
31+
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
32+
}
33+
34+
bool HasNext() const override { return scanner_.HasNext(); }
35+
36+
void ReInit() override { scanner_.Reset(); }
37+
38+
private:
39+
recordio::Scanner scanner_;
40+
const platform::DeviceContext& dev_ctx_;
41+
};
42+
43+
class CreateRecordIOReaderOp : public framework::OperatorBase {
44+
public:
45+
using framework::OperatorBase::OperatorBase;
46+
47+
private:
48+
void RunImpl(const framework::Scope& scope,
49+
const platform::Place& dev_place) const override {
50+
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
51+
const auto& ranks = Attr<std::vector<int>>("ranks");
52+
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
53+
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
54+
int(shape_concat.size()),
55+
"The accumulate of all ranks should be equal to the "
56+
"shape concat's length.");
57+
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
58+
std::string filename = Attr<std::string>("filename");
59+
60+
auto* out = scope.FindVar(Output("Out"))
61+
->template GetMutable<framework::ReaderHolder>();
62+
out->Reset(new RecordIOFileReader(filename, shapes));
63+
}
64+
};
65+
66+
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
67+
public:
68+
CreateRecordIOReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
69+
: FileReaderMakerBase(op_proto, op_checker) {
70+
AddAttr<std::string>("filename", "The filename of record io reader");
71+
AddComment(R"DOC(
72+
CreateRecordIOReader Operator
73+
74+
Create a reader from a record io file
75+
)DOC");
76+
}
77+
};
78+
79+
} // namespace reader
80+
} // namespace operators
81+
} // namespace paddle
82+
83+
namespace reader = paddle::operators::reader;
84+
85+
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
86+
reader::CreateRecordIOReaderOp,
87+
reader::CreateRecordIOReaderOpMaker);

0 commit comments

Comments
 (0)