@@ -52,6 +52,7 @@ static inline void parse_line(
52
52
std::vector<std::string> ret;
53
53
string_split (line, ' ' , &ret);
54
54
*label = std::stoi (ret[2 ]) > 0 ;
55
+
55
56
for (size_t i = 3 ; i < ret.size (); ++i) {
56
57
const std::string& item = ret[i];
57
58
std::vector<std::string> slot_and_feasign;
@@ -62,6 +63,13 @@ static inline void parse_line(
62
63
(*slots_to_data)[slot_and_feasign[1 ]].push_back (feasign);
63
64
}
64
65
}
66
+
67
+ // NOTE:: if the slot has no value, then fill [0] as it's data.
68
+ for (auto & slot : slots) {
69
+ if (slots_to_data->find (slot) == slots_to_data->end ()) {
70
+ (*slots_to_data)[slot].push_back (0 );
71
+ }
72
+ }
65
73
}
66
74
67
75
// class Reader {
@@ -80,9 +88,7 @@ class GzipReader {
80
88
81
89
// Returns true while peeking at the underlying gzip stream does not yield EOF.
bool HasNext() {
  const auto next_char = gzstream_.peek();
  return next_char != EOF;
}
82
90
83
- void NextLine (std::string* line) { // NOLINT
84
- std::getline (gzstream_, line);
85
- }
91
+ void NextLine (std::string* line) { std::getline (gzstream_, *line); }
86
92
87
93
private:
88
94
igzstream gzstream_;
@@ -108,7 +114,7 @@ class MultiGzipReader {
108
114
}
109
115
110
116
// Reads the next line from the reader selected by `current_reader_index_`
// into `*line`. The pointer is forwarded unchanged because the per-file
// reader's NextLine also takes a std::string*.
void NextLine(std::string* line) {
  readers_[current_reader_index_]->NextLine(line);
}
113
119
114
120
private:
@@ -119,16 +125,49 @@ class MultiGzipReader {
119
125
void CTRReader::ReadThread (const std::vector<std::string>& file_list,
120
126
const std::vector<std::string>& slots,
121
127
int batch_size,
122
- std::shared_ptr<LoDTensorBlockingQueue>* queue) {
128
+ std::shared_ptr<LoDTensorBlockingQueue> queue) {
123
129
std::string line;
130
+ std::vector<framework::LoDTensor> read_data;
131
+
132
+ std::vector<std::unordered_map<std::string, std::vector<int64_t >>> batch_data;
133
+ std::vector<int64_t > batch_label;
124
134
125
- // read all files
126
135
MultiGzipReader reader (file_list);
127
- reader.NextLine (&line);
136
+ // read all files
137
+ for (int i = 0 ; i < batch_size; ++i) {
138
+ if (reader.HasNext ()) {
139
+ reader.NextLine (&line);
140
+ std::unordered_map<std::string, std::vector<int64_t >> slots_to_data;
141
+ int64_t label;
142
+ parse_line (line, slots, &label, &slots_to_data);
143
+ batch_data.push_back (slots_to_data);
144
+ batch_label.push_back (label);
145
+ } else {
146
+ break ;
147
+ }
148
+ }
128
149
129
- std::unordered_map<std::string, std::vector<int64_t >> slots_to_data;
130
- int64_t label;
131
- parse_line (line, slots, &label, &slots_to_data);
150
+ std::vector<framework::LoDTensor> lod_datas;
151
+ for (auto & slot : slots) {
152
+ for (auto & slots_to_data : batch_data) {
153
+ std::vector<size_t > lod_data{0 };
154
+ std::vector<int64_t > batch_feasign;
155
+
156
+ auto & feasign = slots_to_data[slot];
157
+
158
+ lod_data.push_back (lod_data.back () + feasign.size ());
159
+ batch_feasign.insert (feasign.end (), feasign.begin (), feasign.end ());
160
+ framework::LoDTensor lod_tensor;
161
+ framework::LoD lod{lod_data};
162
+ lod_tensor.set_lod (lod);
163
+ int64_t * tensor_data = lod_tensor.mutable_data <int64_t >(
164
+ framework::make_ddim ({1 , static_cast <int64_t >(batch_feasign.size ())}),
165
+ platform::CPUPlace ());
166
+ memcpy (tensor_data, batch_feasign.data (), batch_feasign.size ());
167
+ lod_datas.push_back (lod_tensor);
168
+ }
169
+ }
170
+ queue->Push (lod_datas);
132
171
}
133
172
134
173
} // namespace reader
0 commit comments