2323#include < executorch/extension/data_loader/buffer_data_loader.h>
2424#include < executorch/extension/data_loader/mmap_data_loader.h>
2525#include < executorch/extension/memory_allocator/malloc_memory_allocator.h>
26- #include < executorch/extension/module/bundled_module.h>
2726#include < executorch/extension/threadpool/threadpool.h>
2827#include < executorch/runtime/backend/interface.h>
2928#include < executorch/runtime/core/data_loader.h>
@@ -97,7 +96,6 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Program;
9796using ::executorch::extension::BufferDataLoader;
9897using ::executorch::extension::MallocMemoryAllocator;
9998using ::executorch::extension::MmapDataLoader;
100- using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
10199using ::executorch::runtime::ArrayRef;
102100using ::executorch::runtime::DataLoader;
103101using ::executorch::runtime::Error;
@@ -442,54 +440,13 @@ inline std::unique_ptr<Module> load_module_from_file(
442440 program_verification);
443441}
444442
445- inline py::list get_outputs_as_py_list (
446- const std::vector<EValue>& outputs,
447- bool clone_outputs = true ) {
448- const auto outputs_size = outputs.size ();
449- py::list list (outputs_size);
450- for (size_t i = 0 ; i < outputs_size; ++i) {
451- auto & v = outputs[i];
452- if (Tag::None == v.tag ) {
453- list[i] = py::none ();
454- } else if (Tag::Int == v.tag ) {
455- list[i] = py::cast (v.toInt ());
456- } else if (Tag::Double == v.tag ) {
457- list[i] = py::cast (v.toDouble ());
458- } else if (Tag::Bool == v.tag ) {
459- list[i] = py::cast (v.toBool ());
460- } else if (Tag::String == v.tag ) {
461- list[i] = py::cast (std::string (v.toString ().data ()));
462- } else if (Tag::Tensor == v.tag ) {
463- #ifdef USE_ATEN_LIB
464- // Clone so the outputs in python do not share a lifetime with the
465- // module object
466- if (clone_outputs) {
467- list[i] = py::cast (v.toTensor ().clone ());
468- } else {
469- list[i] = py::cast (v.toTensor ());
470- }
471- #else
472- if (clone_outputs) {
473- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
474- } else {
475- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
476- }
477- #endif
478- } else {
479- ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
480- }
481- }
482- return list;
483- }
484-
485443static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U ;
486444
487- struct PyBundledModule : public BundledModule {
445+ struct PyBundledModule final {
488446 explicit PyBundledModule (
489447 const py::bytes& buffer,
490448 uint32_t bundled_input_pool_size)
491- : BundledModule(buffer.cast<std::string_view>().data()),
492- bundled_program_ptr_(buffer),
449+ : bundled_program_ptr_(buffer),
493450 program_ptr_(static_cast <const void *>(
494451 bundled_program_flatbuffer::GetBundledProgram (
495452 get_bundled_program_ptr ())
@@ -518,33 +475,6 @@ struct PyBundledModule : public BundledModule {
518475 return program_len_;
519476 }
520477
521- py::list verify_result_with_bundled_expected_output (
522- const std::string& method_name,
523- size_t testset_idx,
524- double rtol = 1e-5 ,
525- double atol = 1e-8 ) {
526- // Execute the method
527- auto result = BundledModule::execute (method_name, testset_idx);
528- if (!result.ok ()) {
529- THROW_IF_ERROR (
530- result.error (),
531- " Method execution failed with status 0x%" PRIx32,
532- static_cast <uint32_t >(result.error ()));
533- }
534-
535- // Convert outputs to py::list
536- const auto & outputs = result.get ();
537- py::list py_outputs = get_outputs_as_py_list (outputs);
538-
539- Error status = BundledModule::verify_method_outputs (
540- method_name, testset_idx, rtol, atol);
541- THROW_IF_ERROR (
542- status,
543- " Result verification failed with status %" PRIu32,
544- static_cast <uint32_t >(status));
545- return py_outputs;
546- }
547-
548478 private:
549479 // Store the bytes object instead of a raw pointer so that this module will
550480 // keep the bytes alive.
@@ -901,6 +831,43 @@ struct PyModule final {
901831 }
902832 }
903833
834+ void load_bundled_input (
835+ PyBundledModule& m,
836+ const std::string method_name,
837+ size_t testset_idx) {
838+ const void * bundled_program_ptr = m.get_bundled_program_ptr ();
839+ Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
840+ module_->get_method (method_name), bundled_program_ptr, testset_idx);
841+ THROW_IF_ERROR (
842+ status,
843+ " load_bundled_input failed with status 0x%" PRIx32,
844+ static_cast <uint32_t >(status));
845+ }
846+
847+ py::list verify_result_with_bundled_expected_output (
848+ PyBundledModule& m,
849+ const std::string method_name,
850+ size_t testset_idx,
851+ double rtol = 1e-5 ,
852+ double atol = 1e-8 ) {
853+ const void * bundled_program_ptr = m.get_bundled_program_ptr ();
854+ auto & method = module_->get_method (method_name);
855+ Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
856+ method, bundled_program_ptr, testset_idx);
857+ THROW_IF_ERROR (
858+ status,
859+ " load_bundled_input failed with status 0x%" PRIx32,
860+ static_cast <uint32_t >(status));
861+ py::list outputs = plan_execute (method_name);
862+ status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs (
863+ method, bundled_program_ptr, testset_idx, rtol, atol);
864+ THROW_IF_ERROR (
865+ status,
866+ " Result verification failed with status %" PRIu32,
867+ static_cast <uint32_t >(status));
868+ return outputs;
869+ }
870+
904871 py::list plan_execute (
905872 const std::string method_name,
906873 bool clone_outputs = true ) {
@@ -923,6 +890,46 @@ struct PyModule final {
923890 return get_outputs_as_py_list (outputs, clone_outputs);
924891 }
925892
893+ py::list get_outputs_as_py_list (
894+ const std::vector<EValue>& outputs,
895+ bool clone_outputs = true ) {
896+ const auto outputs_size = outputs.size ();
897+ py::list list (outputs_size);
898+ for (size_t i = 0 ; i < outputs_size; ++i) {
899+ auto & v = outputs[i];
900+ if (Tag::None == v.tag ) {
901+ list[i] = py::none ();
902+ } else if (Tag::Int == v.tag ) {
903+ list[i] = py::cast (v.toInt ());
904+ } else if (Tag::Double == v.tag ) {
905+ list[i] = py::cast (v.toDouble ());
906+ } else if (Tag::Bool == v.tag ) {
907+ list[i] = py::cast (v.toBool ());
908+ } else if (Tag::String == v.tag ) {
909+ list[i] = py::cast (std::string (v.toString ().data ()));
910+ } else if (Tag::Tensor == v.tag ) {
911+ #ifdef USE_ATEN_LIB
912+ // Clone so the outputs in python do not share a lifetime with the
913+ // module object
914+ if (clone_outputs) {
915+ list[i] = py::cast (v.toTensor ().clone ());
916+ } else {
917+ list[i] = py::cast (v.toTensor ());
918+ }
919+ #else
920+ if (clone_outputs) {
921+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
922+ } else {
923+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
924+ }
925+ #endif
926+ } else {
927+ ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
928+ }
929+ }
930+ return list;
931+ }
932+
926933 std::unique_ptr<PyMethodMeta> method_meta (const std::string method_name) {
927934 auto & method = module_->get_method (method_name);
928935 return std::make_unique<PyMethodMeta>(module_, method.method_meta ());
@@ -1082,6 +1089,16 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10821089 call_guard);
10831090
10841091 py::class_<PyModule>(m, " ExecuTorchModule" )
1092+ .def (" load_bundled_input" , &PyModule::load_bundled_input, call_guard)
1093+ .def (
1094+ " verify_result_with_bundled_expected_output" ,
1095+ &PyModule::verify_result_with_bundled_expected_output,
1096+ py::arg (" bundle" ),
1097+ py::arg (" method_name" ),
1098+ py::arg (" testset_idx" ),
1099+ py::arg (" rtol" ) = 1e-5 ,
1100+ py::arg (" atol" ) = 1e-8 ,
1101+ call_guard)
10851102 .def (
10861103 " plan_execute" ,
10871104 &PyModule::plan_execute,
@@ -1127,16 +1144,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
11271144 py::arg (" clone_outputs" ) = true ,
11281145 call_guard);
11291146
1130- py::class_<PyBundledModule>(m, " BundledModule" )
1131- .def (
1132- " verify_result_with_bundled_expected_output" ,
1133- &PyBundledModule::verify_result_with_bundled_expected_output,
1134- py::arg (" method_name" ),
1135- py::arg (" testset_idx" ),
1136- py::arg (" rtol" ) = 1e-5 ,
1137- py::arg (" atol" ) = 1e-8 ,
1138- call_guard);
1139-
1147+ py::class_<PyBundledModule>(m, " BundledModule" );
11401148 py::class_<PyTensorInfo>(m, " TensorInfo" )
11411149 .def (" sizes" , &PyTensorInfo::sizes, call_guard)
11421150 .def (" dtype" , &PyTensorInfo::dtype, call_guard)
0 commit comments