@@ -52,6 +52,11 @@ DEFINE_string(
5252 " model.pte" ,
5353 " Model serialized in flatbuffer format." );
5454DEFINE_uint32 (num_executions, 1 , " Number of times to run the model." );
55+ DEFINE_string (input_list_path, " input_list.txt" , " Model input list path." );
56+ DEFINE_string (
57+ output_folder_path,
58+ " outputs" ,
59+ " Executorch inference data output path." );
5560#ifdef ET_EVENT_TRACER_ENABLED
5661DEFINE_string (etdump_path, " model.etdump" , " Write ETDump data to this path." );
5762#endif // ET_EVENT_TRACER_ENABLED
@@ -271,57 +276,143 @@ int main(int argc, char** argv) {
271276 // because inputs whose space gets reused by memory planning (if
272277 // any such inputs exist) will not be preserved for the next
273278 // execution.
274-
275- ET_CHECK_MSG (
279+ std::ifstream input_list (FLAGS_input_list_path);
280+ if (input_list.is_open ()) {
281+ size_t num_inputs = method->inputs_size ();
282+ ET_LOG (Info, " Number of inputs: %zu" , num_inputs);
283+
284+ auto split = [](std::string s, std::string delimiter) {
285+ size_t pos_start = 0 , pos_end, delim_len = delimiter.length ();
286+ std::string token;
287+ std::vector<std::string> res;
288+
289+ while ((pos_end = s.find (delimiter, pos_start)) != std::string::npos) {
290+ token = s.substr (pos_start, pos_end - pos_start);
291+ pos_start = pos_end + delim_len;
292+ res.push_back (token);
293+ }
294+ res.push_back (s.substr (pos_start));
295+ return res;
296+ };
297+
298+ std::string file_path;
299+ int inference_index = 0 ;
300+ double elapsed_time = 0 ;
301+ while (std::getline (input_list, file_path)) {
302+ auto input_files = split (file_path, " " );
303+ if (input_files.size () == 0 ) {
304+ break ;
305+ }
306+ ET_CHECK_MSG (
307+ input_files.size () == num_inputs,
308+ " Number of inputs (%zu) mismatch with input files (%zu)" ,
309+ num_inputs,
310+ input_files.size ());
311+
312+ std::vector<std::vector<char >> input_buf (num_inputs);
313+ for (int input_index = 0 ; input_index < num_inputs; ++input_index) {
314+ MethodMeta method_meta = method->method_meta ();
315+ Result<executorch::runtime::TensorInfo> tensor_meta =
316+ method_meta.input_tensor_meta (input_index);
317+
318+ std::ifstream fin (input_files[input_index], std::ios::binary);
319+ fin.seekg (0 , fin.end );
320+ size_t file_size = fin.tellg ();
321+
322+ input_buf[input_index].resize (file_size);
323+ fin.seekg (0 , fin.beg );
324+ fin.read (
325+ static_cast <char *>(input_buf[input_index].data ()),
326+ file_size);
327+ fin.close ();
328+
329+ ET_CHECK_MSG (
330+ file_size == tensor_meta->nbytes (),
331+ " Input(%d) size mismatch. file bytes: %zu, tensor bytes: %zu" ,
332+ input_index,
333+ file_size,
334+ tensor_meta->nbytes ());
335+
336+ auto impl = executorch::aten::TensorImpl (
337+ tensor_meta->scalar_type (),
338+ /* dim=*/ tensor_meta->sizes ().size (),
339+ const_cast <executorch::aten::TensorImpl::SizesType*>(tensor_meta->sizes ().data ()),
340+ input_buf[input_index].data (),
341+ const_cast <executorch::aten::TensorImpl::DimOrderType*>(
342+ tensor_meta->dim_order ().data ()));
343+ Error ret = method->set_input (executorch::aten::Tensor (&impl), input_index);
344+ ET_CHECK_MSG (
345+ ret == Error::Ok, " Failed to set input tensor: %d" , (int )ret);
346+ }
347+ Error status = method->execute ();
348+ std::vector<EValue> outputs (method->outputs_size ());
349+ status = method->get_outputs (outputs.data (), method->outputs_size ());
350+ ET_CHECK (status == Error::Ok);
351+ for (size_t output_index = 0 ; output_index < method->outputs_size ();
352+ output_index++) {
353+ auto output_tensor = outputs[output_index].toTensor ();
354+ size_t nbytes = output_tensor.nbytes ();
355+ auto output_file_name = FLAGS_output_folder_path + " /output_" +
356+ std::to_string (inference_index) + " _" +
357+ std::to_string (output_index) + " .raw" ;
358+ std::ofstream fout (output_file_name.c_str (), std::ios::binary);
359+ fout.write (output_tensor.const_data_ptr <char >(), nbytes);
360+ fout.close ();
361+ }
362+ ++inference_index;
363+ }
364+ } else {
365+ ET_CHECK_MSG (
276366 inputs.ok (),
277367 " Could not prepare inputs: 0x%" PRIx32,
278368 (uint32_t )inputs.error ());
279- ET_LOG (Debug, " Inputs prepared." );
280- auto before_exec = std::chrono::high_resolution_clock::now ();
281- Error status = method->execute ();
282- auto after_exec = std::chrono::high_resolution_clock::now ();
283- double interval_1st_infs =
284- std::chrono::duration_cast<std::chrono::microseconds>(
285- after_exec - before_exec)
286- .count () /
287- 1000.0 ;
288- ET_CHECK_MSG (
289- status == Error::Ok,
290- " Execution of method %s failed with status 0x%" PRIx32,
291- method_name,
292- (uint32_t )status);
293-
294- // Run the model.
295- before_exec = std::chrono::high_resolution_clock::now ();
296- for (uint32_t i = 0 ; i < FLAGS_num_executions; i++) {
297- status = method->execute ();
369+ ET_LOG (Debug, " Inputs prepared." );
370+
371+ auto before_exec = std::chrono::high_resolution_clock::now ();
372+ Error status = method->execute ();
373+ auto after_exec = std::chrono::high_resolution_clock::now ();
374+ double interval_1st_infs =
375+ std::chrono::duration_cast<std::chrono::microseconds>(
376+ after_exec - before_exec)
377+ .count () /
378+ 1000.0 ;
298379 ET_CHECK_MSG (
299380 status == Error::Ok,
300381 " Execution of method %s failed with status 0x%" PRIx32,
301382 method_name,
302383 (uint32_t )status);
303- }
304- after_exec = std::chrono::high_resolution_clock::now ();
305- double interval_infs = std::chrono::duration_cast<std::chrono::microseconds>(
306- after_exec - before_exec)
307- .count () /
308- 1000.0 / FLAGS_num_executions;
309-
310- if (FLAGS_dump_statistics) {
311- auto output_file_name = " statistics.txt" ;
312- std::ofstream fout (output_file_name);
313- fout << " load: " + std::to_string (interval_load)
314- << " \n 1st: " + std::to_string (interval_1st_infs)
315- << " \n avg: " + std::to_string (interval_infs) << std::endl;
316- fout.close ();
317- }
318- ET_LOG (Info, " Model executed successfully." );
319384
320- if (tracer.get_event_tracer ()) {
321- // Dump ETDump data containing profiling/debugging data to file specified in
322- // command line flag.
323- status = tracer.write_etdump_to_file ();
324- ET_CHECK_MSG (status == Error::Ok, " Failed to save ETDump file." );
385+ // Run the model.
386+ before_exec = std::chrono::high_resolution_clock::now ();
387+ for (uint32_t i = 0 ; i < FLAGS_num_executions; i++) {
388+ status = method->execute ();
389+ ET_CHECK_MSG (
390+ status == Error::Ok,
391+ " Execution of method %s failed with status 0x%" PRIx32,
392+ method_name,
393+ (uint32_t )status);
394+ }
395+ after_exec = std::chrono::high_resolution_clock::now ();
396+ double interval_infs = std::chrono::duration_cast<std::chrono::microseconds>(
397+ after_exec - before_exec)
398+ .count () /
399+ 1000.0 / FLAGS_num_executions;
400+
401+ if (FLAGS_dump_statistics) {
402+ auto output_file_name = " statistics.txt" ;
403+ std::ofstream fout (output_file_name);
404+ fout << " load: " + std::to_string (interval_load)
405+ << " \n 1st: " + std::to_string (interval_1st_infs)
406+ << " \n avg: " + std::to_string (interval_infs) << std::endl;
407+ fout.close ();
408+ }
409+ ET_LOG (Info, " Model executed successfully." );
410+ if (tracer.get_event_tracer ()) {
411+ // Dump ETDump data containing profiling/debugging data to file specified in
412+ // command line flag.
413+ status = tracer.write_etdump_to_file ();
414+ ET_CHECK_MSG (status == Error::Ok, " Failed to save ETDump file." );
415+ }
325416 }
326417
327418 return 0 ;
0 commit comments