Skip to content

Commit 51fcbb6

Browse files
authored
benchmark support many inputs (#6495)
1 parent 2418711 commit 51fcbb6

File tree

1 file changed

+83
-26
lines changed

1 file changed

+83
-26
lines changed

lite/api/benchmark.cc

Lines changed: 83 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,17 @@ inline double GetCurrentUS() {
8484
}
8585

8686
// Serializes a list of input shapes as "d0,d1,...;e0,e1,...;" — dims joined
// by ',' within a shape, each shape terminated by ';'.
// Fixed: the original called input.back() and looped to input.size() - 1,
// which is UB / an unsigned-underflow near-infinite loop for an empty inner
// vector; it also copied every inner vector.
template <class T>
std::string Vector2Str(const std::vector<std::vector<T>>& inputs) {
  std::stringstream ss;
  for (size_t j = 0; j < inputs.size(); j++) {
    const auto& input = inputs[j];  // borrow, don't copy
    for (size_t i = 0; i < input.size(); i++) {
      if (i != 0) {
        ss << ",";
      }
      ss << input[i];
    }
    ss << ";";
  }
  return ss.str();
}
95100

@@ -102,6 +107,22 @@ T ShapeProduction(const std::vector<T>& shape) {
102107
return num;
103108
}
104109

110+
// Parses one comma-separated shape string (e.g. "1,3,224,224") into dims.
// Fixed: the original used atoi, silently truncating any dim above INT_MAX
// even though the result type is int64_t; atoll keeps the full range.
std::vector<int64_t> get_shape(const std::string& str_shape) {
  std::vector<int64_t> shape;
  std::string tmp_str = str_shape;
  while (!tmp_str.empty()) {
    // tmp_str.data() is null-terminated (C++11), so atoll parses the
    // leading dim and ignores everything from the next ',' on.
    int64_t dim = atoll(tmp_str.data());
    shape.push_back(dim);
    size_t next_offset = tmp_str.find(",");
    if (next_offset == std::string::npos) {
      break;
    }
    tmp_str = tmp_str.substr(next_offset + 1);
  }
  return shape;
}
125+
105126
std::vector<int64_t> GetInputShape(const std::string& str_shape) {
106127
std::vector<int64_t> shape;
107128
std::string tmp_str = str_shape;
@@ -118,6 +139,21 @@ std::vector<int64_t> GetInputShape(const std::string& str_shape) {
118139
return shape;
119140
}
120141

142+
// Splits str_in on the given delimiter (default ':', matching the CLI
// convention for multi-input flags). Generalized from a hard-coded ":" so
// the helper can split comma lists too; existing callers are unaffected.
// An empty input yields an empty vector; a trailing delimiter does not
// produce a trailing empty piece (same as the original behavior).
std::vector<std::string> split_string(const std::string& str_in,
                                      char delim = ':') {
  std::vector<std::string> str_out;
  std::string tmp_str = str_in;
  while (!tmp_str.empty()) {
    size_t next_offset = tmp_str.find(delim);
    // substr(0, npos) takes the whole remainder when no delimiter is left.
    str_out.push_back(tmp_str.substr(0, next_offset));
    if (next_offset == std::string::npos) {
      break;
    }
    tmp_str = tmp_str.substr(next_offset + 1);
  }
  return str_out;
}
156+
121157
void PrintUsage() {
122158
std::string help_info =
123159
"Usage: \n"
@@ -190,10 +226,22 @@ void OutputOptModel(const std::string& save_optimized_model_dir) {
190226

191227
void Run(const std::string& model_path,
192228
const std::string& model_name,
193-
const std::vector<int64_t>& input_shape) {
229+
const std::vector<std::vector<int64_t>>& input_shapes) {
194230
int threads = FLAGS_threads;
195231
int power_mode = FLAGS_power_mode;
196-
std::string input_data_path = FLAGS_input_data_path;
232+
std::vector<std::string> input_data_path =
233+
paddle::lite_api::split_string(FLAGS_input_data_path);
234+
if (!input_data_path.empty() &&
235+
input_data_path.size() != input_shapes.size()) {
236+
LOG(FATAL) << "Error: the input_data_path's size is not consistent with "
237+
"input_shape's size , the input_data_path contains "
238+
<< input_data_path.size()
239+
<< " input data path, while input_shape contains "
240+
<< input_shapes.size()
241+
<< " input shape. (members in input_data_path and input_shape "
242+
"sholud be separated by :)";
243+
}
244+
197245
int warmup = FLAGS_warmup;
198246
int repeats = FLAGS_repeats;
199247
std::string result_path = FLAGS_result_path;
@@ -208,21 +256,24 @@ void Run(const std::string& model_path,
208256
auto predictor = lite_api::CreatePaddlePredictor(config);
209257

210258
// set input
211-
auto input_tensor = predictor->GetInput(0);
212-
input_tensor->Resize(input_shape);
213-
auto input_data = input_tensor->mutable_data<float>();
214-
int64_t input_num = ShapeProduction(input_shape);
215-
if (input_data_path.empty()) {
216-
for (int i = 0; i < input_num; ++i) {
217-
input_data[i] = 1.f;
218-
}
219-
} else {
220-
std::fstream fs(input_data_path);
221-
if (!fs.is_open()) {
222-
LOG(FATAL) << "open input image " << input_data_path << " error.";
223-
}
224-
for (int i = 0; i < input_num; i++) {
225-
fs >> input_data[i];
259+
// Feed every model input: input i takes its shape from input_shapes[i] and,
// when data paths were given, reads all of its values from the single file
// input_data_path[i]; with no paths, every element is set to 1.f.
for (size_t i = 0; i < input_shapes.size(); i++) {
  const auto& input_shape = input_shapes[i];
  auto input_tensor = predictor->GetInput(i);
  input_tensor->Resize(input_shape);
  auto input_data = input_tensor->mutable_data<float>();
  int64_t input_num = ShapeProduction(input_shape);
  if (input_data_path.empty()) {
    for (int64_t j = 0; j < input_num; ++j) {
      input_data[j] = 1.f;
    }
  } else {
    // BUG FIX: the path list holds one file per *input* (size checked
    // against input_shapes.size() above), so it must be indexed by i, not
    // by the element index j — the original re-opened input_data_path[j]
    // for every element and read only each file's first value. Open the
    // input's file once and stream all input_num values from it.
    std::fstream fs(input_data_path[i]);
    if (!fs.is_open()) {
      LOG(FATAL) << "open input image " << input_data_path[i] << " error.";
    }
    for (int64_t j = 0; j < input_num; ++j) {
      fs >> input_data[j];
    }
  }
}
228279

@@ -272,8 +323,8 @@ void Run(const std::string& model_path,
272323
LOG(INFO) << "model_name: " << model_name;
273324
LOG(INFO) << "threads: " << threads;
274325
LOG(INFO) << "power_mode: " << power_mode;
275-
LOG(INFO) << "input_data_path: " << input_data_path;
276-
LOG(INFO) << "input_shape: " << Vector2Str(input_shape);
326+
LOG(INFO) << "input_data_path: " << FLAGS_input_data_path;
327+
LOG(INFO) << "input_shape: " << Vector2Str(input_shapes);
277328
LOG(INFO) << "warmup: " << warmup;
278329
LOG(INFO) << "repeats: " << repeats;
279330
LOG(INFO) << "result_path: " << result_path;
@@ -296,8 +347,14 @@ int main(int argc, char** argv) {
296347
}
297348

298349
// Get input shape
299-
std::vector<int64_t> input_shape =
300-
paddle::lite_api::GetInputShape(FLAGS_input_shape);
350+
std::string raw_input_shapes = FLAGS_input_shape;
351+
std::cout << "raw_input_shapes: " << raw_input_shapes << std::endl;
352+
std::vector<std::string> str_input_shapes =
353+
paddle::lite_api::split_string(raw_input_shapes);
354+
std::vector<std::vector<int64_t>> input_shapes;
355+
for (size_t i = 0; i < str_input_shapes.size(); ++i) {
356+
input_shapes.push_back(paddle::lite_api::get_shape(str_input_shapes[i]));
357+
}
301358

302359
// Get model_name and run_model_path
303360
std::string model_name;
@@ -320,7 +377,7 @@ int main(int argc, char** argv) {
320377
}
321378

322379
// Run test
323-
paddle::lite_api::Run(run_model_path, model_name, input_shape);
380+
paddle::lite_api::Run(run_model_path, model_name, input_shapes);
324381

325382
return 0;
326383
}

0 commit comments

Comments
 (0)