Skip to content

Commit 35b79ab

Browse files
authored
Merge pull request #13983 from jacquesqiao/add-ctr-reader
Add ctr reader
2 parents b1dbbb7 + da38772 commit 35b79ab

File tree

8 files changed

+782
-0
lines changed

8 files changed

+782
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ if (NOT WIN32)
214214
# there is no official support of warpctc, nccl, cupti in windows
215215
include(external/warpctc) # download, build, install warpctc
216216
include(cupti)
217+
include(external/gzstream)
217218
endif (NOT WIN32)
218219

219220
if(WITH_DISTRIBUTE)

cmake/external/gzstream.cmake

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
IF(MOBILE_INFERENCE)
17+
return()
18+
ENDIF()
19+
20+
include (ExternalProject)
21+
22+
# NOTE: gzstream is needed when linking with ctr reader.
23+
24+
SET(GZSTREAM_SOURCES_DIR ${THIRD_PARTY_PATH}/gzstream)
25+
SET(GZSTREAM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gzstream)
26+
SET(GZSTREAM_INCLUDE_DIR "${GZSTREAM_INSTALL_DIR}/include/" CACHE PATH "gzstream include directory." FORCE)
27+
28+
ExternalProject_Add(
29+
extern_gzstream
30+
GIT_REPOSITORY "https://github.com/jacquesqiao/gzstream.git"
31+
GIT_TAG ""
32+
PREFIX ${GZSTREAM_SOURCES_DIR}
33+
UPDATE_COMMAND ""
34+
CONFIGURE_COMMAND ""
35+
BUILD_IN_SOURCE 1
36+
BUILD_COMMAND make -j8
37+
INSTALL_COMMAND mkdir -p ${GZSTREAM_INSTALL_DIR}/lib/ && mkdir -p ${GZSTREAM_INSTALL_DIR}/include/
38+
&& cp ${GZSTREAM_SOURCES_DIR}/src/extern_gzstream/libgzstream.a ${GZSTREAM_INSTALL_DIR}/lib
39+
&& cp -r ${GZSTREAM_SOURCES_DIR}/src/extern_gzstream/gzstream.h ${GZSTREAM_INSTALL_DIR}/include
40+
)
41+
42+
ADD_LIBRARY(gzstream STATIC IMPORTED GLOBAL)
43+
SET_PROPERTY(TARGET gzstream PROPERTY IMPORTED_LOCATION
44+
"${GZSTREAM_INSTALL_DIR}/lib/libgzstream.a")
45+
46+
include_directories(${GZSTREAM_INCLUDE_DIR})
47+
ADD_DEPENDENCIES(gzstream extern_gzstream zlib)

paddle/fluid/operators/reader/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
2828
reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
2929
reader_library(create_py_reader_op SRCS create_py_reader_op.cc)
3030

31+
if (NOT WIN32 AND NOT ON_INFER)
32+
cc_library(ctr_reader SRCS ctr_reader.cc DEPS gzstream reader zlib)
33+
cc_test(ctr_reader_test SRCS ctr_reader_test.cc DEPS ctr_reader)
34+
reader_library(create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader)
35+
endif ()
36+
3137
cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc)
3238
# Export local libraries to parent
3339
# set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/reader/ctr_reader.h"
16+
17+
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
18+
#include "paddle/fluid/operators/reader/reader_op_registry.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
namespace reader {
23+
24+
class CreateCTRReaderOp : public framework::OperatorBase {
25+
public:
26+
using framework::OperatorBase::OperatorBase;
27+
28+
private:
29+
void RunImpl(const framework::Scope& scope,
30+
const platform::Place& dev_place) const override {
31+
auto* out = scope.FindVar(Output("Out"))
32+
->template GetMutable<framework::ReaderHolder>();
33+
if (out->Get() != nullptr) return;
34+
35+
const std::string& queue_name = Input("blocking_queue");
36+
auto* queue_holder_var = scope.FindVar(queue_name);
37+
PADDLE_ENFORCE_NOT_NULL(
38+
queue_holder_var,
39+
"No LoDTensorBlockingQueueHolder variable with name %s found",
40+
queue_name);
41+
auto* queue_holder =
42+
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
43+
44+
int thread_num = Attr<int>("thread_num");
45+
std::vector<std::string> slots = Attr<std::vector<std::string>>("slots");
46+
int batch_size = Attr<int>("batch_size");
47+
std::vector<std::string> file_list =
48+
Attr<std::vector<std::string>>("file_list");
49+
out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), batch_size,
50+
thread_num, slots, file_list));
51+
}
52+
};
53+
54+
class CreateCTRReaderOpMaker : public FileReaderMakerBase {
55+
protected:
56+
void Apply() override {
57+
AddInput("blocking_queue",
58+
"Name of the `LoDTensorBlockingQueueHolder` variable");
59+
AddAttr<int>("thread_num", "the thread num to read data");
60+
AddAttr<int>("batch_size", "the batch size of read data");
61+
AddAttr<std::vector<std::string>>("file_list",
62+
"The list of files that need to read");
63+
AddAttr<std::vector<std::string>>(
64+
"slots", "the slots that should be extract from file");
65+
66+
AddComment(R"DOC(
67+
Create CTRReader to support read ctr data with cpp.
68+
)DOC");
69+
}
70+
};
71+
72+
} // namespace reader
73+
} // namespace operators
74+
} // namespace paddle
75+
76+
namespace reader = ::paddle::operators::reader;
77+
78+
REGISTER_FILE_READER_OPERATOR(create_ctr_reader, reader::CreateCTRReaderOp,
79+
reader::CreateCTRReaderOpMaker);
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/reader/ctr_reader.h"
16+
17+
#include <gzstream.h>
18+
19+
#include <cstdlib>
20+
#include <fstream>
21+
#include <iostream>
22+
#include <sstream>
23+
#include <string>
24+
#include <unordered_map>
25+
26+
#include <algorithm>
27+
#include <random>
28+
29+
namespace paddle {
30+
namespace operators {
31+
namespace reader {
32+
33+
static inline void string_split(const std::string& s, const char delimiter,
34+
std::vector<std::string>* output) {
35+
size_t start = 0;
36+
size_t end = s.find_first_of(delimiter);
37+
38+
while (end <= std::string::npos) {
39+
output->emplace_back(s.substr(start, end - start));
40+
if (end == std::string::npos) {
41+
break;
42+
}
43+
start = end + 1;
44+
end = s.find_first_of(delimiter, start);
45+
}
46+
}
47+
48+
static inline void parse_line(
49+
const std::string& line,
50+
const std::unordered_map<std::string, size_t>& slot_to_index,
51+
int64_t* label,
52+
std::unordered_map<std::string, std::vector<int64_t>>* slot_to_data) {
53+
std::vector<std::string> ret;
54+
string_split(line, ' ', &ret);
55+
*label = std::stoi(ret[2]) > 0;
56+
57+
for (size_t i = 3; i < ret.size(); ++i) {
58+
const std::string& item = ret[i];
59+
std::vector<std::string> feasign_and_slot;
60+
string_split(item, ':', &feasign_and_slot);
61+
if (feasign_and_slot.size() == 2 &&
62+
slot_to_index.find(feasign_and_slot[1]) != slot_to_index.end()) {
63+
int64_t feasign = std::strtoll(feasign_and_slot[0].c_str(), NULL, 10);
64+
(*slot_to_data)[feasign_and_slot[1]].push_back(feasign);
65+
}
66+
}
67+
68+
// NOTE:: if the slot has no value, then fill [0] as it's data.
69+
for (auto& item : slot_to_index) {
70+
if (slot_to_data->find(item.first) == slot_to_data->end()) {
71+
(*slot_to_data)[item.first].push_back(0);
72+
}
73+
}
74+
}
75+
76+
class Reader {
77+
public:
78+
virtual ~Reader() {}
79+
virtual bool HasNext() = 0;
80+
virtual void NextLine(std::string* line) = 0;
81+
};
82+
83+
class GzipReader : public Reader {
84+
public:
85+
explicit GzipReader(const std::string& file_name)
86+
: gzstream_(file_name.c_str()) {}
87+
88+
~GzipReader() {}
89+
90+
bool HasNext() override { return gzstream_.peek() != EOF; }
91+
92+
void NextLine(std::string* line) override { std::getline(gzstream_, *line); }
93+
94+
private:
95+
igzstream gzstream_;
96+
};
97+
98+
class MultiGzipReader : public Reader {
99+
public:
100+
explicit MultiGzipReader(const std::vector<std::string>& file_list) {
101+
for (auto& file : file_list) {
102+
readers_.emplace_back(std::make_shared<GzipReader>(file));
103+
}
104+
}
105+
106+
bool HasNext() override {
107+
if (current_reader_index_ >= readers_.size()) {
108+
return false;
109+
}
110+
if (!readers_[current_reader_index_]->HasNext()) {
111+
current_reader_index_++;
112+
return HasNext();
113+
}
114+
return true;
115+
}
116+
117+
void NextLine(std::string* line) override {
118+
readers_[current_reader_index_]->NextLine(line);
119+
}
120+
121+
private:
122+
std::vector<std::shared_ptr<GzipReader>> readers_;
123+
size_t current_reader_index_ = 0;
124+
};
125+
126+
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
127+
std::shared_ptr<LoDTensorBlockingQueue> queue) {
128+
VLOG(30) << "monitor thread in";
129+
bool reader_thread_is_running = true;
130+
while (reader_thread_is_running) {
131+
VLOG(30) << "reader_thread_is_running";
132+
reader_thread_is_running = false;
133+
for (size_t i = 0; i < (*thread_status).size(); ++i) {
134+
if ((*thread_status)[i] == Running) {
135+
VLOG(30) << "reader is running!";
136+
reader_thread_is_running = true;
137+
}
138+
}
139+
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
140+
}
141+
VLOG(30) << "all reader thread is stopped, push empty data into queue";
142+
queue->Push({});
143+
VLOG(30) << "monitor thread exited";
144+
}
145+
146+
void ReadThread(const std::vector<std::string>& file_list,
147+
const std::vector<std::string>& slots, int batch_size,
148+
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
149+
std::shared_ptr<LoDTensorBlockingQueue> queue) {
150+
VLOG(30) << "[" << thread_id << "]"
151+
<< " reader thread start! thread_id = " << thread_id;
152+
for (auto& file : file_list) {
153+
VLOG(30) << "[" << thread_id << "]"
154+
<< " file " << file;
155+
}
156+
(*thread_status)[thread_id] = Running;
157+
VLOG(30) << "set status to running";
158+
159+
std::unordered_map<std::string, size_t> slot_to_index;
160+
for (size_t i = 0; i < slots.size(); ++i) {
161+
slot_to_index[slots[i]] = i;
162+
}
163+
164+
std::string line;
165+
166+
std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data;
167+
std::vector<int64_t> batch_label;
168+
169+
MultiGzipReader reader(file_list);
170+
171+
VLOG(30) << "reader inited";
172+
173+
while (reader.HasNext()) {
174+
batch_data.clear();
175+
batch_data.reserve(batch_size);
176+
177+
batch_label.clear();
178+
batch_label.reserve(batch_size);
179+
180+
// read batch_size data
181+
for (int i = 0; i < batch_size; ++i) {
182+
if (reader.HasNext()) {
183+
reader.NextLine(&line);
184+
std::unordered_map<std::string, std::vector<int64_t>> slot_to_data;
185+
int64_t label;
186+
parse_line(line, slot_to_index, &label, &slot_to_data);
187+
batch_data.push_back(slot_to_data);
188+
batch_label.push_back(label);
189+
} else {
190+
break;
191+
}
192+
}
193+
194+
std::vector<framework::LoDTensor> lod_datas;
195+
196+
// first insert tensor for each slots
197+
for (auto& slot : slots) {
198+
std::vector<size_t> lod_data{0};
199+
std::vector<int64_t> batch_feasign;
200+
201+
for (size_t i = 0; i < batch_data.size(); ++i) {
202+
auto& feasign = batch_data[i][slot];
203+
lod_data.push_back(lod_data.back() + feasign.size());
204+
batch_feasign.insert(batch_feasign.end(), feasign.begin(),
205+
feasign.end());
206+
}
207+
208+
framework::LoDTensor lod_tensor;
209+
framework::LoD lod{lod_data};
210+
lod_tensor.set_lod(lod);
211+
int64_t* tensor_data = lod_tensor.mutable_data<int64_t>(
212+
framework::make_ddim({1, static_cast<int64_t>(batch_feasign.size())}),
213+
platform::CPUPlace());
214+
memcpy(tensor_data, batch_feasign.data(),
215+
batch_feasign.size() * sizeof(int64_t));
216+
lod_datas.push_back(lod_tensor);
217+
}
218+
219+
// insert label tensor
220+
framework::LoDTensor label_tensor;
221+
auto* label_tensor_data = label_tensor.mutable_data<int64_t>(
222+
framework::make_ddim({1, static_cast<int64_t>(batch_label.size())}),
223+
platform::CPUPlace());
224+
memcpy(label_tensor_data, batch_label.data(),
225+
batch_label.size() * sizeof(int64_t));
226+
lod_datas.push_back(label_tensor);
227+
228+
queue->Push(lod_datas);
229+
VLOG(40) << "push one data, queue_size=" << queue->Size();
230+
}
231+
232+
(*thread_status)[thread_id] = Stopped;
233+
VLOG(30) << "set status to stopped, thread " << thread_id << " exited";
234+
}
235+
236+
} // namespace reader
237+
} // namespace operators
238+
} // namespace paddle

0 commit comments

Comments
 (0)