@@ -84,12 +84,17 @@ inline double GetCurrentUS() {
84
84
}
85
85
86
86
// Serializes a list of shape vectors into a string: elements of one row are
// joined with ',' and each row is terminated with ';',
// e.g. {{1,2,3},{4,5}} -> "1,2,3;4,5;".
template <class T>
std::string Vector2Str(const std::vector<std::vector<T>>& inputs) {
  std::stringstream ss;
  for (size_t j = 0; j < inputs.size(); j++) {
    // const ref: the old code copied every row on each iteration.
    const auto& input = inputs[j];
    // `i + 1 < size()` instead of `i < size() - 1`: size() is unsigned, so
    // the old form underflowed to SIZE_MAX for an empty row (UB reads).
    for (size_t i = 0; i + 1 < input.size(); i++) {
      ss << input[i] << ",";
    }
    if (!input.empty()) {
      ss << input.back();  // guarded: back() on an empty row is UB
    }
    ss << ";";
  }
  return ss.str();
}
95
100
@@ -102,6 +107,22 @@ T ShapeProduction(const std::vector<T>& shape) {
102
107
return num;
103
108
}
104
109
110
// Parses one comma-separated shape string, e.g. "1,3,224,224", into a
// vector of dims. Returns an empty vector for an empty input string.
std::vector<int64_t> get_shape(const std::string& str_shape) {
  std::vector<int64_t> shape;
  std::string tmp_str = str_shape;
  while (!tmp_str.empty()) {
    // atoll, not atoi: dims are int64_t and must not be truncated to int.
    int64_t dim = atoll(tmp_str.data());
    shape.push_back(dim);
    size_t next_offset = tmp_str.find(",");
    if (next_offset == std::string::npos) {
      break;
    } else {
      tmp_str = tmp_str.substr(next_offset + 1);
    }
  }
  return shape;
}
125
+
105
126
std::vector<int64_t > GetInputShape (const std::string& str_shape) {
106
127
std::vector<int64_t > shape;
107
128
std::string tmp_str = str_shape;
@@ -118,6 +139,21 @@ std::vector<int64_t> GetInputShape(const std::string& str_shape) {
118
139
return shape;
119
140
}
120
141
142
// Splits str_in on ':' into its (possibly empty) fields; an empty input
// yields an empty vector, and a trailing ':' does not add an empty field.
std::vector<std::string> split_string(const std::string& str_in) {
  std::vector<std::string> str_out;
  std::string remain = str_in;
  for (;;) {
    if (remain.empty()) {
      break;
    }
    const size_t sep = remain.find(":");
    str_out.push_back(remain.substr(0, sep));
    if (sep == std::string::npos) {
      break;
    }
    remain = remain.substr(sep + 1);
  }
  return str_out;
}
156
+
121
157
void PrintUsage () {
122
158
std::string help_info =
123
159
" Usage: \n "
@@ -190,10 +226,22 @@ void OutputOptModel(const std::string& save_optimized_model_dir) {
190
226
191
227
void Run (const std::string& model_path,
192
228
const std::string& model_name,
193
- const std::vector<int64_t >& input_shape ) {
229
+ const std::vector<std::vector< int64_t >>& input_shapes ) {
194
230
int threads = FLAGS_threads;
195
231
int power_mode = FLAGS_power_mode;
196
- std::string input_data_path = FLAGS_input_data_path;
232
+ std::vector<std::string> input_data_path =
233
+ paddle::lite_api::split_string (FLAGS_input_data_path);
234
+ if (!input_data_path.empty () &&
235
+ input_data_path.size () != input_shapes.size ()) {
236
+ LOG (FATAL) << " Error: the input_data_path's size is not consistent with "
237
+ " input_shape's size , the input_data_path contains "
238
+ << input_data_path.size ()
239
+ << " input data path, while input_shape contains "
240
+ << input_shapes.size ()
241
+ << " input shape. (members in input_data_path and input_shape "
242
+ " sholud be separated by :)" ;
243
+ }
244
+
197
245
int warmup = FLAGS_warmup;
198
246
int repeats = FLAGS_repeats;
199
247
std::string result_path = FLAGS_result_path;
@@ -208,21 +256,24 @@ void Run(const std::string& model_path,
208
256
auto predictor = lite_api::CreatePaddlePredictor (config);
209
257
210
258
// set input
211
- auto input_tensor = predictor->GetInput (0 );
212
- input_tensor->Resize (input_shape);
213
- auto input_data = input_tensor->mutable_data <float >();
214
- int64_t input_num = ShapeProduction (input_shape);
215
- if (input_data_path.empty ()) {
216
- for (int i = 0 ; i < input_num; ++i) {
217
- input_data[i] = 1 .f ;
218
- }
219
- } else {
220
- std::fstream fs (input_data_path);
221
- if (!fs.is_open ()) {
222
- LOG (FATAL) << " open input image " << input_data_path << " error." ;
223
- }
224
- for (int i = 0 ; i < input_num; i++) {
225
- fs >> input_data[i];
259
+ for (int i = 0 ; i < input_shapes.size (); i++) {
260
+ auto input_shape = input_shapes[i];
261
+ auto input_tensor = predictor->GetInput (i);
262
+ input_tensor->Resize (input_shape);
263
+ auto input_data = input_tensor->mutable_data <float >();
264
+ int64_t input_num = ShapeProduction (input_shape);
265
+ if (input_data_path.empty ()) {
266
+ for (int j = 0 ; j < input_num; ++j) {
267
+ input_data[j] = 1 .f ;
268
+ }
269
+ } else {
270
+ for (int j = 0 ; j < input_num; j++) {
271
+ std::fstream fs (input_data_path[j]);
272
+ if (!fs.is_open ()) {
273
+ LOG (FATAL) << " open input image " << input_data_path[j] << " error." ;
274
+ }
275
+ fs >> input_data[j];
276
+ }
226
277
}
227
278
}
228
279
@@ -272,8 +323,8 @@ void Run(const std::string& model_path,
272
323
LOG (INFO) << " model_name: " << model_name;
273
324
LOG (INFO) << " threads: " << threads;
274
325
LOG (INFO) << " power_mode: " << power_mode;
275
- LOG (INFO) << " input_data_path: " << input_data_path ;
276
- LOG (INFO) << " input_shape: " << Vector2Str (input_shape );
326
+ LOG (INFO) << " input_data_path: " << FLAGS_input_data_path ;
327
+ LOG (INFO) << " input_shape: " << Vector2Str (input_shapes );
277
328
LOG (INFO) << " warmup: " << warmup;
278
329
LOG (INFO) << " repeats: " << repeats;
279
330
LOG (INFO) << " result_path: " << result_path;
@@ -296,8 +347,14 @@ int main(int argc, char** argv) {
296
347
}
297
348
298
349
// Get input shape
299
- std::vector<int64_t > input_shape =
300
- paddle::lite_api::GetInputShape (FLAGS_input_shape);
350
+ std::string raw_input_shapes = FLAGS_input_shape;
351
+ std::cout << " raw_input_shapes: " << raw_input_shapes << std::endl;
352
+ std::vector<std::string> str_input_shapes =
353
+ paddle::lite_api::split_string (raw_input_shapes);
354
+ std::vector<std::vector<int64_t >> input_shapes;
355
+ for (size_t i = 0 ; i < str_input_shapes.size (); ++i) {
356
+ input_shapes.push_back (paddle::lite_api::get_shape (str_input_shapes[i]));
357
+ }
301
358
302
359
// Get model_name and run_model_path
303
360
std::string model_name;
@@ -320,7 +377,7 @@ int main(int argc, char** argv) {
320
377
}
321
378
322
379
// Run test
323
- paddle::lite_api::Run (run_model_path, model_name, input_shape );
380
+ paddle::lite_api::Run (run_model_path, model_name, input_shapes );
324
381
325
382
return 0 ;
326
383
}
0 commit comments