@@ -132,31 +132,36 @@ void CTRReader::ReadThread(const std::vector<std::string>& file_list,
132
132
std::vector<int64_t > batch_label;
133
133
134
134
MultiGzipReader reader (file_list);
135
- // read all files
136
- for (int i = 0 ; i < batch_size; ++i) {
137
- if (reader.HasNext ()) {
138
- reader.NextLine (&line);
139
- std::unordered_map<std::string, std::vector<int64_t >> slots_to_data;
140
- int64_t label;
141
- parse_line (line, slots, &label, &slots_to_data);
142
- batch_data.push_back (slots_to_data);
143
- batch_label.push_back (label);
144
- } else {
145
- break ;
135
+
136
+ while (reader.HasNext ()) {
137
+ // read all files
138
+ for (int i = 0 ; i < batch_size; ++i) {
139
+ if (reader.HasNext ()) {
140
+ reader.NextLine (&line);
141
+ std::unordered_map<std::string, std::vector<int64_t >> slots_to_data;
142
+ int64_t label;
143
+ parse_line (line, slots, &label, &slots_to_data);
144
+ batch_data.push_back (slots_to_data);
145
+ batch_label.push_back (label);
146
+ } else {
147
+ break ;
148
+ }
146
149
}
147
- }
148
150
149
- std::vector<framework::LoDTensor> lod_datas;
150
- for (auto & slot : slots) {
151
- for (auto & slots_to_data : batch_data) {
151
+ std::vector<framework::LoDTensor> lod_datas;
152
+
153
+ // first insert tensor for each slots
154
+ for (auto & slot : slots) {
152
155
std::vector<size_t > lod_data{0 };
153
156
std::vector<int64_t > batch_feasign;
154
- std::vector<int64_t > batch_label;
155
157
156
- auto & feasign = slots_to_data[slot];
158
+ for (size_t i = 0 ; i < batch_data.size (); ++i) {
159
+ auto & feasign = batch_data[i][slot];
160
+
161
+ lod_data.push_back (lod_data.back () + feasign.size ());
162
+ batch_feasign.insert (feasign.end (), feasign.begin (), feasign.end ());
163
+ }
157
164
158
- lod_data.push_back (lod_data.back () + feasign.size ());
159
- batch_feasign.insert (feasign.end (), feasign.begin (), feasign.end ());
160
165
framework::LoDTensor lod_tensor;
161
166
framework::LoD lod{lod_data};
162
167
lod_tensor.set_lod (lod);
@@ -166,8 +171,17 @@ void CTRReader::ReadThread(const std::vector<std::string>& file_list,
166
171
memcpy (tensor_data, batch_feasign.data (), batch_feasign.size ());
167
172
lod_datas.push_back (lod_tensor);
168
173
}
174
+
175
+ // insert label tensor
176
+ framework::LoDTensor label_tensor;
177
+ int64_t * label_tensor_data = label_tensor.mutable_data <int64_t >(
178
+ framework::make_ddim ({1 , static_cast <int64_t >(batch_label.size ())}),
179
+ platform::CPUPlace ());
180
+ memcpy (label_tensor_data, batch_label.data (), batch_label.size ());
181
+ lod_datas.push_back (label_tensor);
182
+
183
+ queue->Push (lod_datas);
169
184
}
170
- queue->Push (lod_datas);
171
185
}
172
186
173
187
} // namespace reader
0 commit comments