2323#include < executorch/extension/data_loader/buffer_data_loader.h>
2424#include < executorch/extension/data_loader/mmap_data_loader.h>
2525#include < executorch/extension/memory_allocator/malloc_memory_allocator.h>
26+ #include < executorch/extension/module/bundled_module.h>
2627#include < executorch/extension/threadpool/threadpool.h>
2728#include < executorch/runtime/backend/interface.h>
2829#include < executorch/runtime/core/data_loader.h>
@@ -425,13 +426,54 @@ inline std::unique_ptr<Module> load_module_from_file(
425426 program_verification);
426427}
427428
429+ inline py::list get_outputs_as_py_list (
430+ const std::vector<EValue>& outputs,
431+ bool clone_outputs = true ) {
432+ const auto outputs_size = outputs.size ();
433+ py::list list (outputs_size);
434+ for (size_t i = 0 ; i < outputs_size; ++i) {
435+ auto & v = outputs[i];
436+ if (Tag::None == v.tag ) {
437+ list[i] = py::none ();
438+ } else if (Tag::Int == v.tag ) {
439+ list[i] = py::cast (v.toInt ());
440+ } else if (Tag::Double == v.tag ) {
441+ list[i] = py::cast (v.toDouble ());
442+ } else if (Tag::Bool == v.tag ) {
443+ list[i] = py::cast (v.toBool ());
444+ } else if (Tag::String == v.tag ) {
445+ list[i] = py::cast (std::string (v.toString ().data ()));
446+ } else if (Tag::Tensor == v.tag ) {
447+ #ifdef USE_ATEN_LIB
448+ // Clone so the outputs in python do not share a lifetime with the
449+ // module object
450+ if (clone_outputs) {
451+ list[i] = py::cast (v.toTensor ().clone ());
452+ } else {
453+ list[i] = py::cast (v.toTensor ());
454+ }
455+ #else
456+ if (clone_outputs) {
457+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
458+ } else {
459+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
460+ }
461+ #endif
462+ } else {
463+ ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
464+ }
465+ }
466+ return list;
467+ }
468+
// Default byte size (16 KiB) of the allocator pool used for bundled-program
// inputs when the caller does not supply a size.
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;

430- struct PyBundledModule final {
471+ struct PyBundledModule : public BundledModule {
431472 explicit PyBundledModule (
432473 const py::bytes& buffer,
433474 uint32_t bundled_input_pool_size)
434- : bundled_program_ptr_(buffer),
475+ : BundledModule(buffer.cast<std::string_view>().data()),
476+ bundled_program_ptr_(buffer),
435477 program_ptr_(static_cast <const void *>(
436478 bundled_program_flatbuffer::GetBundledProgram (
437479 get_bundled_program_ptr ())
@@ -460,6 +502,33 @@ struct PyBundledModule final {
460502 return program_len_;
461503 }
462504
505+ py::list verify_result_with_bundled_expected_output (
506+ const std::string& method_name,
507+ size_t testset_idx,
508+ double rtol = 1e-5 ,
509+ double atol = 1e-8 ) {
510+ // Execute the method
511+ auto result = BundledModule::execute (method_name, testset_idx);
512+ if (!result.ok ()) {
513+ THROW_IF_ERROR (
514+ result.error (),
515+ " Method execution failed with status 0x%" PRIx32,
516+ static_cast <uint32_t >(result.error ()));
517+ }
518+
519+ // Convert outputs to py::list
520+ const auto & outputs = result.get ();
521+ py::list py_outputs = get_outputs_as_py_list (outputs);
522+
523+ Error status = BundledModule::verify_method_outputs (
524+ method_name, testset_idx, rtol, atol);
525+ THROW_IF_ERROR (
526+ status,
527+ " Result verification failed with status %" PRIu32,
528+ static_cast <uint32_t >(status));
529+ return py_outputs;
530+ }
531+
463532 private:
464533 // Store the bytes object instead of a raw pointer so that this module will
465534 // keep the bytes alive.
@@ -816,43 +885,6 @@ struct PyModule final {
816885 }
817886 }
818887
819- void load_bundled_input (
820- PyBundledModule& m,
821- const std::string method_name,
822- size_t testset_idx) {
823- const void * bundled_program_ptr = m.get_bundled_program_ptr ();
824- Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
825- module_->get_method (method_name), bundled_program_ptr, testset_idx);
826- THROW_IF_ERROR (
827- status,
828- " load_bundled_input failed with status 0x%" PRIx32,
829- static_cast <uint32_t >(status));
830- }
831-
832- py::list verify_result_with_bundled_expected_output (
833- PyBundledModule& m,
834- const std::string method_name,
835- size_t testset_idx,
836- double rtol = 1e-5 ,
837- double atol = 1e-8 ) {
838- const void * bundled_program_ptr = m.get_bundled_program_ptr ();
839- auto & method = module_->get_method (method_name);
840- Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
841- method, bundled_program_ptr, testset_idx);
842- THROW_IF_ERROR (
843- status,
844- " load_bundled_input failed with status 0x%" PRIx32,
845- static_cast <uint32_t >(status));
846- py::list outputs = plan_execute (method_name);
847- status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs (
848- method, bundled_program_ptr, testset_idx, rtol, atol);
849- THROW_IF_ERROR (
850- status,
851- " Result verification failed with status %" PRIu32,
852- static_cast <uint32_t >(status));
853- return outputs;
854- }
855-
856888 py::list plan_execute (
857889 const std::string method_name,
858890 bool clone_outputs = true ) {
@@ -875,46 +907,6 @@ struct PyModule final {
875907 return get_outputs_as_py_list (outputs, clone_outputs);
876908 }
877909
878- py::list get_outputs_as_py_list (
879- const std::vector<EValue>& outputs,
880- bool clone_outputs = true ) {
881- const auto outputs_size = outputs.size ();
882- py::list list (outputs_size);
883- for (size_t i = 0 ; i < outputs_size; ++i) {
884- auto & v = outputs[i];
885- if (Tag::None == v.tag ) {
886- list[i] = py::none ();
887- } else if (Tag::Int == v.tag ) {
888- list[i] = py::cast (v.toInt ());
889- } else if (Tag::Double == v.tag ) {
890- list[i] = py::cast (v.toDouble ());
891- } else if (Tag::Bool == v.tag ) {
892- list[i] = py::cast (v.toBool ());
893- } else if (Tag::String == v.tag ) {
894- list[i] = py::cast (std::string (v.toString ().data ()));
895- } else if (Tag::Tensor == v.tag ) {
896- #ifdef USE_ATEN_LIB
897- // Clone so the outputs in python do not share a lifetime with the
898- // module object
899- if (clone_outputs) {
900- list[i] = py::cast (v.toTensor ().clone ());
901- } else {
902- list[i] = py::cast (v.toTensor ());
903- }
904- #else
905- if (clone_outputs) {
906- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
907- } else {
908- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
909- }
910- #endif
911- } else {
912- ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
913- }
914- }
915- return list;
916- }
917-
918910 std::unique_ptr<PyMethodMeta> method_meta (const std::string method_name) {
919911 auto & method = module_->get_method (method_name);
920912 return std::make_unique<PyMethodMeta>(module_, method.method_meta ());
@@ -1074,16 +1066,6 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10741066 call_guard);
10751067
10761068 py::class_<PyModule>(m, " ExecuTorchModule" )
1077- .def (" load_bundled_input" , &PyModule::load_bundled_input, call_guard)
1078- .def (
1079- " verify_result_with_bundled_expected_output" ,
1080- &PyModule::verify_result_with_bundled_expected_output,
1081- py::arg (" bundle" ),
1082- py::arg (" method_name" ),
1083- py::arg (" testset_idx" ),
1084- py::arg (" rtol" ) = 1e-5 ,
1085- py::arg (" atol" ) = 1e-8 ,
1086- call_guard)
10871069 .def (
10881070 " plan_execute" ,
10891071 &PyModule::plan_execute,
@@ -1129,7 +1111,16 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
11291111 py::arg (" clone_outputs" ) = true ,
11301112 call_guard);
11311113
1132- py::class_<PyBundledModule>(m, " BundledModule" );
  // Python-facing binding for PyBundledModule. Tolerance defaults mirror the
  // C++ defaults of verify_result_with_bundled_expected_output.
  py::class_<PyBundledModule>(m, "BundledModule")
      .def(
          "verify_result_with_bundled_expected_output",
          &PyBundledModule::verify_result_with_bundled_expected_output,
          py::arg("method_name"),
          py::arg("testset_idx"),
          py::arg("rtol") = 1e-5,
          py::arg("atol") = 1e-8,
          call_guard);
1123+
11331124 py::class_<PyTensorInfo>(m, " TensorInfo" )
11341125 .def (" sizes" , &PyTensorInfo::sizes, call_guard)
11351126 .def (" dtype" , &PyTensorInfo::dtype, call_guard)
0 commit comments