Skip to content

Commit 803e2ed

Browse files
committed
add ctr_reader_test and fix bug
1 parent c8bd521 commit 803e2ed

File tree

4 files changed

+108
-22
lines changed

4 files changed

+108
-22
lines changed

paddle/fluid/operators/reader/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ endfunction()
1717

1818
cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool)
1919
cc_library(ctr_reader SRCS ctr_reader.cc DEPS reader simple_threadpool boost gzstream)
20+
cc_test(ctr_reader_test SRCS ctr_reader_test.cc DEPS ctr_reader)
2021
reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader)
2122
reader_library(create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader)
2223
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)

paddle/fluid/operators/reader/ctr_reader.cc

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,32 +46,47 @@ static inline void string_split(const std::string& s, const char delimiter,
4646
}
4747

4848
static inline void parse_line(
49-
const std::string& line, const std::vector<std::string>& slots,
49+
const std::string& line,
50+
const std::unordered_map<std::string, size_t>& slot_to_index,
5051
int64_t* label,
51-
std::unordered_map<std::string, std::vector<int64_t>>* slots_to_data) {
52+
std::unordered_map<std::string, std::vector<int64_t>>* slot_to_data) {
5253
std::vector<std::string> ret;
5354
string_split(line, ' ', &ret);
5455
*label = std::stoi(ret[2]) > 0;
5556

5657
for (size_t i = 3; i < ret.size(); ++i) {
5758
const std::string& item = ret[i];
58-
std::vector<std::string> slot_and_feasign;
59-
string_split(item, ':', &slot_and_feasign);
60-
if (slot_and_feasign.size() == 2) {
61-
const std::string& slot = slot_and_feasign[1];
62-
int64_t feasign = std::strtoll(slot_and_feasign[0].c_str(), NULL, 10);
63-
(*slots_to_data)[slot_and_feasign[1]].push_back(feasign);
59+
std::vector<std::string> feasign_and_slot;
60+
string_split(item, ':', &feasign_and_slot);
61+
auto& slot = feasign_and_slot[1];
62+
if (feasign_and_slot.size() == 2 &&
63+
slot_to_index.find(slot) != slot_to_index.end()) {
64+
const std::string& slot = feasign_and_slot[1];
65+
int64_t feasign = std::strtoll(feasign_and_slot[0].c_str(), NULL, 10);
66+
(*slot_to_data)[feasign_and_slot[1]].push_back(feasign);
6467
}
6568
}
6669

6770
// NOTE:: if the slot has no value, then fill [0] as it's data.
68-
for (auto& slot : slots) {
69-
if (slots_to_data->find(slot) == slots_to_data->end()) {
70-
(*slots_to_data)[slot].push_back(0);
71+
for (auto& item : slot_to_index) {
72+
if (slot_to_data->find(item.first) == slot_to_data->end()) {
73+
(*slot_to_data)[item.first].push_back(0);
7174
}
7275
}
7376
}
7477

78+
static void print_map(
79+
std::unordered_map<std::string, std::vector<int64_t>>* map) {
80+
for (auto it = map->begin(); it != map->end(); ++it) {
81+
std::cout << it->first << " -> ";
82+
std::cout << "[";
83+
for (auto& i : it->second) {
84+
std::cout << i << " ";
85+
}
86+
std::cout << "]\n";
87+
}
88+
}
89+
7590
class Reader {
7691
public:
7792
virtual ~Reader() {}
@@ -126,7 +141,14 @@ void ReadThread(const std::vector<std::string>& file_list,
126141
const std::vector<std::string>& slots, int batch_size,
127142
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
128143
std::shared_ptr<LoDTensorBlockingQueue> queue) {
144+
VLOG(3) << "reader thread start! thread_id = " << thread_id;
129145
(*thread_status)[thread_id] = Running;
146+
VLOG(3) << "set status to running";
147+
148+
std::unordered_map<std::string, size_t> slot_to_index;
149+
for (size_t i = 0; i < slots.size(); ++i) {
150+
slot_to_index[slots[i]] = i;
151+
}
130152

131153
std::string line;
132154

@@ -135,21 +157,29 @@ void ReadThread(const std::vector<std::string>& file_list,
135157

136158
MultiGzipReader reader(file_list);
137159

160+
VLOG(3) << "reader inited";
161+
138162
while (reader.HasNext()) {
139-
// read all files
163+
batch_data.clear();
164+
batch_label.clear();
165+
166+
// read batch_size data
140167
for (int i = 0; i < batch_size; ++i) {
141168
if (reader.HasNext()) {
142169
reader.NextLine(&line);
143-
std::unordered_map<std::string, std::vector<int64_t>> slots_to_data;
170+
std::unordered_map<std::string, std::vector<int64_t>> slot_to_data;
144171
int64_t label;
145-
parse_line(line, slots, &label, &slots_to_data);
146-
batch_data.push_back(slots_to_data);
172+
parse_line(line, slot_to_index, &label, &slot_to_data);
173+
batch_data.push_back(slot_to_data);
147174
batch_label.push_back(label);
148175
} else {
149176
break;
150177
}
151178
}
152179

180+
VLOG(3) << "read one batch, batch_size = " << batch_data.size();
181+
print_map(&batch_data[0]);
182+
153183
std::vector<framework::LoDTensor> lod_datas;
154184

155185
// first insert tensor for each slots
@@ -159,9 +189,9 @@ void ReadThread(const std::vector<std::string>& file_list,
159189

160190
for (size_t i = 0; i < batch_data.size(); ++i) {
161191
auto& feasign = batch_data[i][slot];
162-
163192
lod_data.push_back(lod_data.back() + feasign.size());
164-
batch_feasign.insert(feasign.end(), feasign.begin(), feasign.end());
193+
batch_feasign.insert(batch_feasign.end(), feasign.begin(),
194+
feasign.end());
165195
}
166196

167197
framework::LoDTensor lod_tensor;
@@ -174,6 +204,8 @@ void ReadThread(const std::vector<std::string>& file_list,
174204
lod_datas.push_back(lod_tensor);
175205
}
176206

207+
VLOG(3) << "convert data to tensor";
208+
177209
// insert label tensor
178210
framework::LoDTensor label_tensor;
179211
int64_t* label_tensor_data = label_tensor.mutable_data<int64_t>(
@@ -182,10 +214,12 @@ void ReadThread(const std::vector<std::string>& file_list,
182214
memcpy(label_tensor_data, batch_label.data(), batch_label.size());
183215
lod_datas.push_back(label_tensor);
184216

217+
VLOG(3) << "push one data";
185218
queue->Push(lod_datas);
186219
}
187220

188221
(*thread_status)[thread_id] = Stopped;
222+
VLOG(3) << "thread " << thread_id << " exited";
189223
}
190224

191225
} // namespace reader

paddle/fluid/operators/reader/ctr_reader.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,15 @@ class CTRReader : public framework::FileReader {
4747
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
4848
PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty");
4949
thread_num_ =
50-
file_list_.size() > thread_num_ ? thread_num_ : file_list_.size();
50+
file_list_.size() > thread_num ? thread_num : file_list_.size();
5151
queue_ = queue;
5252
SplitFiles();
53-
for (int i = 0; i < thread_num; ++i) {
53+
for (int i = 0; i < thread_num_; ++i) {
5454
read_thread_status_.push_back(Stopped);
5555
}
5656
}
5757

58-
~CTRReader() { queue_->Close(); }
58+
~CTRReader() { Shutdown(); }
5959

6060
void ReadNext(std::vector<framework::LoDTensor>* out) override {
6161
bool success;
@@ -74,8 +74,11 @@ class CTRReader : public framework::FileReader {
7474

7575
void Start() override {
7676
VLOG(3) << "Start reader";
77+
PADDLE_ENFORCE_EQ(read_threads_.size(), 0, "read thread should be empty!");
7778
queue_->ReOpen();
78-
for (int thread_id = 0; thread_id < file_groups_.size(); thread_id++) {
79+
VLOG(3) << "reopen success";
80+
VLOG(3) << "thread_num " << thread_num_;
81+
for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
7982
read_threads_.emplace_back(new std::thread(
8083
std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_,
8184
thread_id, &read_thread_status_, queue_)));
@@ -86,7 +89,10 @@ class CTRReader : public framework::FileReader {
8689
void SplitFiles() {
8790
file_groups_.resize(thread_num_);
8891
for (int i = 0; i < file_list_.size(); ++i) {
89-
file_groups_[i % thread_num_].push_back(file_list_[i]);
92+
auto& file_name = file_list_[i];
93+
std::ifstream f(file_name.c_str());
94+
PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name);
95+
file_groups_[i % thread_num_].push_back(file_name);
9096
}
9197
}
9298

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "gtest/gtest.h"
16+
17+
#include "paddle/fluid/framework/lod_tensor.h"
18+
#include "paddle/fluid/operators/reader/blocking_queue.h"
19+
#include "paddle/fluid/operators/reader/ctr_reader.h"
20+
21+
using paddle::operators::reader::LoDTensorBlockingQueue;
22+
using paddle::operators::reader::LoDTensorBlockingQueueHolder;
23+
using paddle::operators::reader::CTRReader;
24+
25+
TEST(CTR_READER, read_data) {
26+
LoDTensorBlockingQueueHolder queue_holder;
27+
int capacity = 64;
28+
queue_holder.InitOnce(capacity, {}, false);
29+
30+
std::shared_ptr<LoDTensorBlockingQueue> queue = queue_holder.GetQueue();
31+
32+
int batch_size = 10;
33+
int thread_num = 1;
34+
std::vector<std::string> slots = {"6003", "6004"};
35+
std::vector<std::string> file_list = {
36+
"/Users/qiaolongfei/project/gzip_test/part-00000-A.gz",
37+
"/Users/qiaolongfei/project/gzip_test/part-00000-A.gz"};
38+
39+
CTRReader reader(queue, batch_size, thread_num, slots, file_list);
40+
41+
reader.Start();
42+
//
43+
// std::vector<LoDTensor> out;
44+
// reader.ReadNext(&out);
45+
}

0 commit comments

Comments
 (0)