Skip to content

Commit 694e894

Browse files
committed
add a base class for reader
1 parent d981333 commit 694e894

File tree

1 file changed

+34
-20
lines changed

1 file changed

+34
-20
lines changed

paddle/fluid/operators/reader/ctr_reader.cc

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -132,31 +132,36 @@ void CTRReader::ReadThread(const std::vector<std::string>& file_list,
132132
std::vector<int64_t> batch_label;
133133

134134
MultiGzipReader reader(file_list);
135-
// read all files
136-
for (int i = 0; i < batch_size; ++i) {
137-
if (reader.HasNext()) {
138-
reader.NextLine(&line);
139-
std::unordered_map<std::string, std::vector<int64_t>> slots_to_data;
140-
int64_t label;
141-
parse_line(line, slots, &label, &slots_to_data);
142-
batch_data.push_back(slots_to_data);
143-
batch_label.push_back(label);
144-
} else {
145-
break;
135+
136+
while (reader.HasNext()) {
137+
// read all files
138+
for (int i = 0; i < batch_size; ++i) {
139+
if (reader.HasNext()) {
140+
reader.NextLine(&line);
141+
std::unordered_map<std::string, std::vector<int64_t>> slots_to_data;
142+
int64_t label;
143+
parse_line(line, slots, &label, &slots_to_data);
144+
batch_data.push_back(slots_to_data);
145+
batch_label.push_back(label);
146+
} else {
147+
break;
148+
}
146149
}
147-
}
148150

149-
std::vector<framework::LoDTensor> lod_datas;
150-
for (auto& slot : slots) {
151-
for (auto& slots_to_data : batch_data) {
151+
std::vector<framework::LoDTensor> lod_datas;
152+
153+
// first, insert one tensor for each slot
154+
for (auto& slot : slots) {
152155
std::vector<size_t> lod_data{0};
153156
std::vector<int64_t> batch_feasign;
154-
std::vector<int64_t> batch_label;
155157

156-
auto& feasign = slots_to_data[slot];
158+
for (size_t i = 0; i < batch_data.size(); ++i) {
159+
auto& feasign = batch_data[i][slot];
160+
161+
lod_data.push_back(lod_data.back() + feasign.size());
162+
batch_feasign.insert(feasign.end(), feasign.begin(), feasign.end());
163+
}
157164

158-
lod_data.push_back(lod_data.back() + feasign.size());
159-
batch_feasign.insert(feasign.end(), feasign.begin(), feasign.end());
160165
framework::LoDTensor lod_tensor;
161166
framework::LoD lod{lod_data};
162167
lod_tensor.set_lod(lod);
@@ -166,8 +171,17 @@ void CTRReader::ReadThread(const std::vector<std::string>& file_list,
166171
memcpy(tensor_data, batch_feasign.data(), batch_feasign.size());
167172
lod_datas.push_back(lod_tensor);
168173
}
174+
175+
// insert label tensor
176+
framework::LoDTensor label_tensor;
177+
int64_t* label_tensor_data = label_tensor.mutable_data<int64_t>(
178+
framework::make_ddim({1, static_cast<int64_t>(batch_label.size())}),
179+
platform::CPUPlace());
180+
memcpy(label_tensor_data, batch_label.data(), batch_label.size());
181+
lod_datas.push_back(label_tensor);
182+
183+
queue->Push(lod_datas);
169184
}
170-
queue->Push(lod_datas);
171185
}
172186

173187
} // namespace reader

0 commit comments

Comments
 (0)