2323#include < executorch/extension/data_loader/buffer_data_loader.h>
2424#include < executorch/extension/data_loader/mmap_data_loader.h>
2525#include < executorch/extension/memory_allocator/malloc_memory_allocator.h>
26- #include < executorch/extension/module/bundled_module.h>
2726#include < executorch/extension/threadpool/threadpool.h>
2827#include < executorch/runtime/backend/interface.h>
2928#include < executorch/runtime/core/data_loader.h>
@@ -441,54 +440,13 @@ inline std::unique_ptr<Module> load_module_from_file(
441440 program_verification);
442441}
443442
444- inline py::list get_outputs_as_py_list (
445- const std::vector<EValue>& outputs,
446- bool clone_outputs = true ) {
447- const auto outputs_size = outputs.size ();
448- py::list list (outputs_size);
449- for (size_t i = 0 ; i < outputs_size; ++i) {
450- auto & v = outputs[i];
451- if (Tag::None == v.tag ) {
452- list[i] = py::none ();
453- } else if (Tag::Int == v.tag ) {
454- list[i] = py::cast (v.toInt ());
455- } else if (Tag::Double == v.tag ) {
456- list[i] = py::cast (v.toDouble ());
457- } else if (Tag::Bool == v.tag ) {
458- list[i] = py::cast (v.toBool ());
459- } else if (Tag::String == v.tag ) {
460- list[i] = py::cast (std::string (v.toString ().data ()));
461- } else if (Tag::Tensor == v.tag ) {
462- #ifdef USE_ATEN_LIB
463- // Clone so the outputs in python do not share a lifetime with the
464- // module object
465- if (clone_outputs) {
466- list[i] = py::cast (v.toTensor ().clone ());
467- } else {
468- list[i] = py::cast (v.toTensor ());
469- }
470- #else
471- if (clone_outputs) {
472- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
473- } else {
474- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
475- }
476- #endif
477- } else {
478- ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
479- }
480- }
481- return list;
482- }
483-
484443static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U ;
485444
486- struct PyBundledModule : public BundledModule {
445+ struct PyBundledModule final {
487446 explicit PyBundledModule (
488447 const py::bytes& buffer,
489448 uint32_t bundled_input_pool_size)
490- : BundledModule(buffer.cast<std::string_view>().data()),
491- bundled_program_ptr_(buffer),
449+ : bundled_program_ptr_(buffer),
492450 program_ptr_(static_cast <const void *>(
493451 bundled_program_flatbuffer::GetBundledProgram (
494452 get_bundled_program_ptr ())
@@ -517,32 +475,6 @@ struct PyBundledModule : public BundledModule {
517475 return program_len_;
518476 }
519477
520- py::list verify_result_with_bundled_expected_output (
521- const std::string& method_name,
522- size_t testset_idx,
523- double rtol = 1e-5 ,
524- double atol = 1e-8 ) {
525- // Execute the method
526- auto result = BundledModule::execute (method_name, testset_idx);
527- if (!result.ok ()) {
528- THROW_IF_ERROR (
529- result.error (),
530- " Method execution failed with status 0x%" PRIx32,
531- static_cast <uint32_t >(result.error ()));
532- }
533-
534- // Convert outputs to py::list
535- const auto & outputs = result.get ();
536- py::list py_outputs = get_outputs_as_py_list (outputs);
537-
538- Error status = BundledModule::verify_method_outputs (method_name, testset_idx, rtol, atol);
539- THROW_IF_ERROR (
540- status,
541- " Result verification failed with status %" PRIu32,
542- static_cast <uint32_t >(status));
543- return py_outputs;
544- }
545-
546478 private:
547479 // Store the bytes object instead of a raw pointer so that this module will
548480 // keep the bytes alive.
@@ -859,7 +791,7 @@ struct PyModule final {
859791 }
860792
861793 py::list forward_single_input (
862- const torch::Tensor& inputTensor,
794+ const torch::Tensor& inputTensor,
863795 bool clone_outputs = true ) {
864796 py::list py_list;
865797 py_list.append (py::cast (inputTensor));
@@ -899,6 +831,43 @@ const torch::Tensor& inputTensor,
899831 }
900832 }
901833
834+ void load_bundled_input (
835+ PyBundledModule& m,
836+ const std::string method_name,
837+ size_t testset_idx) {
838+ const void * bundled_program_ptr = m.get_bundled_program_ptr ();
839+ Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
840+ module_->get_method (method_name), bundled_program_ptr, testset_idx);
841+ THROW_IF_ERROR (
842+ status,
843+ " load_bundled_input failed with status 0x%" PRIx32,
844+ static_cast <uint32_t >(status));
845+ }
846+
847+ py::list verify_result_with_bundled_expected_output (
848+ PyBundledModule& m,
849+ const std::string method_name,
850+ size_t testset_idx,
851+ double rtol = 1e-5 ,
852+ double atol = 1e-8 ) {
853+ const void * bundled_program_ptr = m.get_bundled_program_ptr ();
854+ auto & method = module_->get_method (method_name);
855+ Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
856+ method, bundled_program_ptr, testset_idx);
857+ THROW_IF_ERROR (
858+ status,
859+ " load_bundled_input failed with status 0x%" PRIx32,
860+ static_cast <uint32_t >(status));
861+ py::list outputs = plan_execute (method_name);
862+ status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs (
863+ method, bundled_program_ptr, testset_idx, rtol, atol);
864+ THROW_IF_ERROR (
865+ status,
866+ " Result verification failed with status %" PRIu32,
867+ static_cast <uint32_t >(status));
868+ return outputs;
869+ }
870+
902871 py::list plan_execute (
903872 const std::string method_name,
904873 bool clone_outputs = true ) {
@@ -921,6 +890,46 @@ const torch::Tensor& inputTensor,
921890 return get_outputs_as_py_list (outputs, clone_outputs);
922891 }
923892
893+ py::list get_outputs_as_py_list (
894+ const std::vector<EValue>& outputs,
895+ bool clone_outputs = true ) {
896+ const auto outputs_size = outputs.size ();
897+ py::list list (outputs_size);
898+ for (size_t i = 0 ; i < outputs_size; ++i) {
899+ auto & v = outputs[i];
900+ if (Tag::None == v.tag ) {
901+ list[i] = py::none ();
902+ } else if (Tag::Int == v.tag ) {
903+ list[i] = py::cast (v.toInt ());
904+ } else if (Tag::Double == v.tag ) {
905+ list[i] = py::cast (v.toDouble ());
906+ } else if (Tag::Bool == v.tag ) {
907+ list[i] = py::cast (v.toBool ());
908+ } else if (Tag::String == v.tag ) {
909+ list[i] = py::cast (std::string (v.toString ().data ()));
910+ } else if (Tag::Tensor == v.tag ) {
911+ #ifdef USE_ATEN_LIB
912+ // Clone so the outputs in python do not share a lifetime with the
913+ // module object
914+ if (clone_outputs) {
915+ list[i] = py::cast (v.toTensor ().clone ());
916+ } else {
917+ list[i] = py::cast (v.toTensor ());
918+ }
919+ #else
920+ if (clone_outputs) {
921+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
922+ } else {
923+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
924+ }
925+ #endif
926+ } else {
927+ ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
928+ }
929+ }
930+ return list;
931+ }
932+
924933 std::unique_ptr<PyMethodMeta> method_meta (const std::string method_name) {
925934 auto & method = module_->get_method (method_name);
926935 return std::make_unique<PyMethodMeta>(module_, method.method_meta ());
@@ -1080,6 +1089,16 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10801089 call_guard);
10811090
10821091 py::class_<PyModule>(m, " ExecuTorchModule" )
1092+ .def (" load_bundled_input" , &PyModule::load_bundled_input, call_guard)
1093+ .def (
1094+ " verify_result_with_bundled_expected_output" ,
1095+ &PyModule::verify_result_with_bundled_expected_output,
1096+ py::arg (" bundle" ),
1097+ py::arg (" method_name" ),
1098+ py::arg (" testset_idx" ),
1099+ py::arg (" rtol" ) = 1e-5 ,
1100+ py::arg (" atol" ) = 1e-8 ,
1101+ call_guard)
10831102 .def (
10841103 " plan_execute" ,
10851104 &PyModule::plan_execute,
@@ -1125,15 +1144,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
11251144 py::arg (" clone_outputs" ) = true ,
11261145 call_guard);
11271146
1128- py::class_<PyBundledModule>(m, " BundledModule" ).def (
1129- " verify_result_with_bundled_expected_output" ,
1130- &PyBundledModule::verify_result_with_bundled_expected_output,
1131- py::arg (" method_name" ),
1132- py::arg (" testset_idx" ),
1133- py::arg (" rtol" ) = 1e-5 ,
1134- py::arg (" atol" ) = 1e-8 ,
1135- call_guard);
1136-
1147+ py::class_<PyBundledModule>(m, " BundledModule" );
11371148 py::class_<PyTensorInfo>(m, " TensorInfo" )
11381149 .def (" sizes" , &PyTensorInfo::sizes, call_guard)
11391150 .def (" dtype" , &PyTensorInfo::dtype, call_guard)
0 commit comments