Skip to content

Commit 71c2ad4

Browse files
committed
complete read thread
1 parent 0f3ece7 commit 71c2ad4

File tree

2 files changed

+50
-11
lines changed

2 files changed

+50
-11
lines changed

paddle/fluid/operators/reader/ctr_reader.cc

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ static inline void parse_line(
5252
std::vector<std::string> ret;
5353
string_split(line, ' ', &ret);
5454
*label = std::stoi(ret[2]) > 0;
55+
5556
for (size_t i = 3; i < ret.size(); ++i) {
5657
const std::string& item = ret[i];
5758
std::vector<std::string> slot_and_feasign;
@@ -62,6 +63,13 @@ static inline void parse_line(
6263
(*slots_to_data)[slot_and_feasign[1]].push_back(feasign);
6364
}
6465
}
66+
67+
// NOTE: if the slot has no value, then fill [0] as its data.
68+
for (auto& slot : slots) {
69+
if (slots_to_data->find(slot) == slots_to_data->end()) {
70+
(*slots_to_data)[slot].push_back(0);
71+
}
72+
}
6573
}
6674

6775
// class Reader {
@@ -80,9 +88,7 @@ class GzipReader {
8088

8189
// Reports whether more input is available: peeks at the gzip stream and
// returns false once the peeked value is EOF.
bool HasNext() {
  const auto upcoming = gzstream_.peek();
  return upcoming != EOF;
}
8290

83-
void NextLine(std::string* line) { // NOLINT
84-
std::getline(gzstream_, line);
85-
}
91+
// Reads the next line from the gzip stream into the string pointed to by
// `line` (std::getline replaces the string's previous contents).
void NextLine(std::string* line) {
  std::string& destination = *line;
  std::getline(gzstream_, destination);
}
8692

8793
private:
8894
igzstream gzstream_;
@@ -108,7 +114,7 @@ class MultiGzipReader {
108114
}
109115

110116
// Forwards the read to the reader selected by current_reader_index_,
// passing the caller's output string pointer through unchanged.
void NextLine(std::string* line) {
  auto& active_reader = readers_[current_reader_index_];
  active_reader->NextLine(line);
}
113119

114120
private:
@@ -119,16 +125,49 @@ class MultiGzipReader {
119125
void CTRReader::ReadThread(const std::vector<std::string>& file_list,
120126
const std::vector<std::string>& slots,
121127
int batch_size,
122-
std::shared_ptr<LoDTensorBlockingQueue>* queue) {
128+
std::shared_ptr<LoDTensorBlockingQueue> queue) {
123129
std::string line;
130+
std::vector<framework::LoDTensor> read_data;
131+
132+
std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data;
133+
std::vector<int64_t> batch_label;
124134

125-
// read all files
126135
MultiGzipReader reader(file_list);
127-
reader.NextLine(&line);
136+
// read all files
137+
for (int i = 0; i < batch_size; ++i) {
138+
if (reader.HasNext()) {
139+
reader.NextLine(&line);
140+
std::unordered_map<std::string, std::vector<int64_t>> slots_to_data;
141+
int64_t label;
142+
parse_line(line, slots, &label, &slots_to_data);
143+
batch_data.push_back(slots_to_data);
144+
batch_label.push_back(label);
145+
} else {
146+
break;
147+
}
148+
}
128149

129-
std::unordered_map<std::string, std::vector<int64_t>> slots_to_data;
130-
int64_t label;
131-
parse_line(line, slots, &label, &slots_to_data);
150+
std::vector<framework::LoDTensor> lod_datas;
151+
for (auto& slot : slots) {
152+
for (auto& slots_to_data : batch_data) {
153+
std::vector<size_t> lod_data{0};
154+
std::vector<int64_t> batch_feasign;
155+
156+
auto& feasign = slots_to_data[slot];
157+
158+
lod_data.push_back(lod_data.back() + feasign.size());
159+
batch_feasign.insert(feasign.end(), feasign.begin(), feasign.end());
160+
framework::LoDTensor lod_tensor;
161+
framework::LoD lod{lod_data};
162+
lod_tensor.set_lod(lod);
163+
int64_t* tensor_data = lod_tensor.mutable_data<int64_t>(
164+
framework::make_ddim({1, static_cast<int64_t>(batch_feasign.size())}),
165+
platform::CPUPlace());
166+
memcpy(tensor_data, batch_feasign.data(), batch_feasign.size());
167+
lod_datas.push_back(lod_tensor);
168+
}
169+
}
170+
queue->Push(lod_datas);
132171
}
133172

134173
} // namespace reader

paddle/fluid/operators/reader/ctr_reader.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class CTRReader : public framework::FileReader {
6868
private:
6969
void ReadThread(const std::vector<std::string>& file_list,
7070
const std::vector<std::string>& slots, int batch_size,
71-
std::shared_ptr<LoDTensorBlockingQueue>* queue);
71+
std::shared_ptr<LoDTensorBlockingQueue> queue);
7272

7373
private:
7474
std::shared_ptr<LoDTensorBlockingQueue> queue_;

0 commit comments

Comments
 (0)