@@ -55,6 +55,38 @@ static void generatedata(const std::vector<std::string>& data,
55
55
PADDLE_ENFORCE (out.good (), " save file %s failed!" , file_name);
56
56
}
57
57
58
+ static inline void check_all_data (
59
+ const std::vector<std::string>& ctr_data,
60
+ const std::vector<std::string>& slots, const std::vector<DDim>& label_dims,
61
+ const std::vector<int64_t >& label_value,
62
+ const std::vector<std::tuple<LoD, std::vector<int64_t >>>& data_slot_6002,
63
+ const std::vector<std::tuple<LoD, std::vector<int64_t >>>& data_slot_6003,
64
+ size_t batch_num, size_t batch_size,
65
+ std::shared_ptr<LoDTensorBlockingQueue> queue, CTRReader* reader) {
66
+ std::vector<LoDTensor> out;
67
+ for (size_t i = 0 ; i < batch_num; ++i) {
68
+ reader->ReadNext (&out);
69
+ ASSERT_EQ (out.size (), slots.size () + 1 );
70
+ auto & label_tensor = out.back ();
71
+ ASSERT_EQ (label_tensor.dims (), label_dims[i]);
72
+ for (size_t j = 0 ; j < batch_size && i * batch_num + j < ctr_data.size ();
73
+ ++j) {
74
+ auto & label = label_tensor.data <int64_t >()[j];
75
+ ASSERT_TRUE (label == 0 || label == 1 );
76
+ ASSERT_EQ (label, label_value[i * batch_size + j]);
77
+ }
78
+ auto & tensor_6002 = out[0 ];
79
+ ASSERT_EQ (std::get<0 >(data_slot_6002[i]), tensor_6002.lod ());
80
+ ASSERT_EQ (std::memcmp (std::get<1 >(data_slot_6002[i]).data (),
81
+ tensor_6002.data <int64_t >(),
82
+ tensor_6002.dims ()[1 ] * sizeof (int64_t )),
83
+ 0 );
84
+ }
85
+ reader->ReadNext (&out);
86
+ ASSERT_EQ (out.size (), 0 );
87
+ ASSERT_EQ (queue->Size (), 0 );
88
+ }
89
+
58
90
TEST (CTR_READER, read_data) {
59
91
const std::vector<std::string> ctr_data = {
60
92
" aaaa 1 0 0:6002 1:6003 2:6004 3:6005 4:6006 -1\n " ,
@@ -103,35 +135,15 @@ TEST(CTR_READER, read_data) {
103
135
CTRReader reader (queue, batch_size, thread_num, slots, file_list);
104
136
105
137
reader.Start ();
106
-
107
138
size_t batch_num =
108
139
std::ceil (static_cast <float >(ctr_data.size ()) / batch_size) * thread_num;
140
+ check_all_data (ctr_data, slots, label_dims, label_value, data_slot_6002,
141
+ data_slot_6003, batch_num, batch_size, queue, &reader);
109
142
110
- std::vector<LoDTensor> out;
111
- for (size_t i = 0 ; i < batch_num; ++i) {
112
- reader.ReadNext (&out);
113
- ASSERT_EQ (out.size (), slots.size () + 1 );
114
- auto & label_tensor = out.back ();
115
- ASSERT_EQ (label_tensor.dims (), label_dims[i]);
116
- for (size_t j = 0 ; j < batch_size && i * batch_num + j < ctr_data.size ();
117
- ++j) {
118
- auto & label = label_tensor.data <int64_t >()[j];
119
- ASSERT_TRUE (label == 0 || label == 1 );
120
- ASSERT_EQ (label, label_value[i * batch_size + j]);
121
- }
122
- auto & tensor_6002 = out[0 ];
123
- ASSERT_EQ (std::get<0 >(data_slot_6002[i]), tensor_6002.lod ());
124
- ASSERT_EQ (std::memcmp (std::get<1 >(data_slot_6002[i]).data (),
125
- tensor_6002.data <int64_t >(),
126
- tensor_6002.dims ()[1 ] * sizeof (int64_t )),
127
- 0 );
128
- }
129
- reader.ReadNext (&out);
130
- ASSERT_EQ (out.size (), 0 );
131
- ASSERT_EQ (queue->Size (), 0 );
132
143
reader.Shutdown ();
133
144
134
145
reader.Start ();
146
+ check_all_data (ctr_data, slots, label_dims, label_value, data_slot_6002,
147
+ data_slot_6003, batch_num, batch_size, queue, &reader);
135
148
reader.Shutdown ();
136
- ASSERT_EQ (queue->Size (), 5 );
137
149
}
0 commit comments