Skip to content

Commit d37b979

Browse files
committed
update test
1 parent 4051fb3 commit d37b979

File tree

1 file changed

+36
-24
lines changed

1 file changed

+36
-24
lines changed

paddle/fluid/operators/reader/ctr_reader_test.cc

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,38 @@ static void generatedata(const std::vector<std::string>& data,
5555
PADDLE_ENFORCE(out.good(), "save file %s failed!", file_name);
5656
}
5757

58+
static inline void check_all_data(
59+
const std::vector<std::string>& ctr_data,
60+
const std::vector<std::string>& slots, const std::vector<DDim>& label_dims,
61+
const std::vector<int64_t>& label_value,
62+
const std::vector<std::tuple<LoD, std::vector<int64_t>>>& data_slot_6002,
63+
const std::vector<std::tuple<LoD, std::vector<int64_t>>>& data_slot_6003,
64+
size_t batch_num, size_t batch_size,
65+
std::shared_ptr<LoDTensorBlockingQueue> queue, CTRReader* reader) {
66+
std::vector<LoDTensor> out;
67+
for (size_t i = 0; i < batch_num; ++i) {
68+
reader->ReadNext(&out);
69+
ASSERT_EQ(out.size(), slots.size() + 1);
70+
auto& label_tensor = out.back();
71+
ASSERT_EQ(label_tensor.dims(), label_dims[i]);
72+
for (size_t j = 0; j < batch_size && i * batch_num + j < ctr_data.size();
73+
++j) {
74+
auto& label = label_tensor.data<int64_t>()[j];
75+
ASSERT_TRUE(label == 0 || label == 1);
76+
ASSERT_EQ(label, label_value[i * batch_size + j]);
77+
}
78+
auto& tensor_6002 = out[0];
79+
ASSERT_EQ(std::get<0>(data_slot_6002[i]), tensor_6002.lod());
80+
ASSERT_EQ(std::memcmp(std::get<1>(data_slot_6002[i]).data(),
81+
tensor_6002.data<int64_t>(),
82+
tensor_6002.dims()[1] * sizeof(int64_t)),
83+
0);
84+
}
85+
reader->ReadNext(&out);
86+
ASSERT_EQ(out.size(), 0);
87+
ASSERT_EQ(queue->Size(), 0);
88+
}
89+
5890
TEST(CTR_READER, read_data) {
5991
const std::vector<std::string> ctr_data = {
6092
"aaaa 1 0 0:6002 1:6003 2:6004 3:6005 4:6006 -1\n",
@@ -103,35 +135,15 @@ TEST(CTR_READER, read_data) {
103135
CTRReader reader(queue, batch_size, thread_num, slots, file_list);
104136

105137
reader.Start();
106-
107138
size_t batch_num =
108139
std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num;
140+
check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002,
141+
data_slot_6003, batch_num, batch_size, queue, &reader);
109142

110-
std::vector<LoDTensor> out;
111-
for (size_t i = 0; i < batch_num; ++i) {
112-
reader.ReadNext(&out);
113-
ASSERT_EQ(out.size(), slots.size() + 1);
114-
auto& label_tensor = out.back();
115-
ASSERT_EQ(label_tensor.dims(), label_dims[i]);
116-
for (size_t j = 0; j < batch_size && i * batch_num + j < ctr_data.size();
117-
++j) {
118-
auto& label = label_tensor.data<int64_t>()[j];
119-
ASSERT_TRUE(label == 0 || label == 1);
120-
ASSERT_EQ(label, label_value[i * batch_size + j]);
121-
}
122-
auto& tensor_6002 = out[0];
123-
ASSERT_EQ(std::get<0>(data_slot_6002[i]), tensor_6002.lod());
124-
ASSERT_EQ(std::memcmp(std::get<1>(data_slot_6002[i]).data(),
125-
tensor_6002.data<int64_t>(),
126-
tensor_6002.dims()[1] * sizeof(int64_t)),
127-
0);
128-
}
129-
reader.ReadNext(&out);
130-
ASSERT_EQ(out.size(), 0);
131-
ASSERT_EQ(queue->Size(), 0);
132143
reader.Shutdown();
133144

134145
reader.Start();
146+
check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002,
147+
data_slot_6003, batch_num, batch_size, queue, &reader);
135148
reader.Shutdown();
136-
ASSERT_EQ(queue->Size(), 5);
137149
}

0 commit comments

Comments
 (0)