@@ -23,7 +23,6 @@
 #include <executorch/extension/data_loader/buffer_data_loader.h>
 #include <executorch/extension/data_loader/mmap_data_loader.h>
 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
-#include <executorch/extension/module/bundled_module.h>
 #include <executorch/extension/threadpool/threadpool.h>
 #include <executorch/runtime/backend/interface.h>
 #include <executorch/runtime/core/data_loader.h>
@@ -442,54 +441,13 @@ inline std::unique_ptr<Module> load_module_from_file(
       program_verification);
 }
 
-inline py::list get_outputs_as_py_list(
-    const std::vector<EValue>& outputs,
-    bool clone_outputs = true) {
-  const auto outputs_size = outputs.size();
-  py::list list(outputs_size);
-  for (size_t i = 0; i < outputs_size; ++i) {
-    auto& v = outputs[i];
-    if (Tag::None == v.tag) {
-      list[i] = py::none();
-    } else if (Tag::Int == v.tag) {
-      list[i] = py::cast(v.toInt());
-    } else if (Tag::Double == v.tag) {
-      list[i] = py::cast(v.toDouble());
-    } else if (Tag::Bool == v.tag) {
-      list[i] = py::cast(v.toBool());
-    } else if (Tag::String == v.tag) {
-      list[i] = py::cast(std::string(v.toString().data()));
-    } else if (Tag::Tensor == v.tag) {
-#ifdef USE_ATEN_LIB
-      // Clone so the outputs in python do not share a lifetime with the
-      // module object
-      if (clone_outputs) {
-        list[i] = py::cast(v.toTensor().clone());
-      } else {
-        list[i] = py::cast(v.toTensor());
-      }
-#else
-      if (clone_outputs) {
-        list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
-      } else {
-        list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
-      }
-#endif
-    } else {
-      ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
-    }
-  }
-  return list;
-}
-
 static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
 
-struct PyBundledModule : public BundledModule {
+struct PyBundledModule final {
   explicit PyBundledModule(
       const py::bytes& buffer,
       uint32_t bundled_input_pool_size)
-      : BundledModule(buffer.cast<std::string_view>().data()),
-        bundled_program_ptr_(buffer),
+      : bundled_program_ptr_(buffer),
         program_ptr_(static_cast<const void*>(
             bundled_program_flatbuffer::GetBundledProgram(
                 get_bundled_program_ptr())
@@ -518,32 +476,6 @@ struct PyBundledModule : public BundledModule {
     return program_len_;
   }
 
-  py::list verify_result_with_bundled_expected_output(
-      const std::string& method_name,
-      size_t testset_idx,
-      double rtol = 1e-5,
-      double atol = 1e-8) {
-    // Execute the method
-    auto result = BundledModule::execute(method_name, testset_idx);
-    if (!result.ok()) {
-      THROW_IF_ERROR(
-          result.error(),
-          "Method execution failed with status 0x%" PRIx32,
-          static_cast<uint32_t>(result.error()));
-    }
-
-    // Convert outputs to py::list
-    const auto& outputs = result.get();
-    py::list py_outputs = get_outputs_as_py_list(outputs);
-
-    Error status = BundledModule::verify_method_outputs(method_name, testset_idx, rtol, atol);
-    THROW_IF_ERROR(
-        status,
-        "Result verification failed with status %" PRIu32,
-        static_cast<uint32_t>(status));
-    return py_outputs;
-  }
-
  private:
   // Store the bytes object instead of a raw pointer so that this module will
   // keep the bytes alive.
@@ -860,7 +792,7 @@ struct PyModule final {
   }
 
   py::list forward_single_input(
-const torch::Tensor& inputTensor,
+      const torch::Tensor& inputTensor,
       bool clone_outputs = true) {
     py::list py_list;
     py_list.append(py::cast(inputTensor));
@@ -900,6 +832,43 @@ const torch::Tensor& inputTensor,
     }
   }
 
+  void load_bundled_input(
+      PyBundledModule& m,
+      const std::string method_name,
+      size_t testset_idx) {
+    const void* bundled_program_ptr = m.get_bundled_program_ptr();
+    Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
+        module_->get_method(method_name), bundled_program_ptr, testset_idx);
+    THROW_IF_ERROR(
+        status,
+        "load_bundled_input failed with status 0x%" PRIx32,
+        static_cast<uint32_t>(status));
+  }
+
+  py::list verify_result_with_bundled_expected_output(
+      PyBundledModule& m,
+      const std::string method_name,
+      size_t testset_idx,
+      double rtol = 1e-5,
+      double atol = 1e-8) {
+    const void* bundled_program_ptr = m.get_bundled_program_ptr();
+    auto& method = module_->get_method(method_name);
+    Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
+        method, bundled_program_ptr, testset_idx);
+    THROW_IF_ERROR(
+        status,
+        "load_bundled_input failed with status 0x%" PRIx32,
+        static_cast<uint32_t>(status));
+    py::list outputs = plan_execute(method_name);
+    status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs(
+        method, bundled_program_ptr, testset_idx, rtol, atol);
+    THROW_IF_ERROR(
+        status,
+        "Result verification failed with status %" PRIu32,
+        static_cast<uint32_t>(status));
+    return outputs;
+  }
+
   py::list plan_execute(
       const std::string method_name,
       bool clone_outputs = true) {
@@ -922,6 +891,46 @@ const torch::Tensor& inputTensor,
     return get_outputs_as_py_list(outputs, clone_outputs);
   }
 
+  py::list get_outputs_as_py_list(
+      const std::vector<EValue>& outputs,
+      bool clone_outputs = true) {
+    const auto outputs_size = outputs.size();
+    py::list list(outputs_size);
+    for (size_t i = 0; i < outputs_size; ++i) {
+      auto& v = outputs[i];
+      if (Tag::None == v.tag) {
+        list[i] = py::none();
+      } else if (Tag::Int == v.tag) {
+        list[i] = py::cast(v.toInt());
+      } else if (Tag::Double == v.tag) {
+        list[i] = py::cast(v.toDouble());
+      } else if (Tag::Bool == v.tag) {
+        list[i] = py::cast(v.toBool());
+      } else if (Tag::String == v.tag) {
+        list[i] = py::cast(std::string(v.toString().data()));
+      } else if (Tag::Tensor == v.tag) {
+#ifdef USE_ATEN_LIB
+        // Clone so the outputs in python do not share a lifetime with the
+        // module object
+        if (clone_outputs) {
+          list[i] = py::cast(v.toTensor().clone());
+        } else {
+          list[i] = py::cast(v.toTensor());
+        }
+#else
+        if (clone_outputs) {
+          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
+        } else {
+          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
+        }
+#endif
+      } else {
+        ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
+      }
+    }
+    return list;
+  }
+
   std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
     auto& method = module_->get_method(method_name);
     return std::make_unique<PyMethodMeta>(module_, method.method_meta());
@@ -1081,6 +1090,16 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
       call_guard);
 
   py::class_<PyModule>(m, "ExecuTorchModule")
+      .def("load_bundled_input", &PyModule::load_bundled_input, call_guard)
+      .def(
+          "verify_result_with_bundled_expected_output",
+          &PyModule::verify_result_with_bundled_expected_output,
+          py::arg("bundle"),
+          py::arg("method_name"),
+          py::arg("testset_idx"),
+          py::arg("rtol") = 1e-5,
+          py::arg("atol") = 1e-8,
+          call_guard)
       .def(
           "plan_execute",
           &PyModule::plan_execute,
@@ -1126,15 +1145,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
           py::arg("clone_outputs") = true,
           call_guard);
 
-  py::class_<PyBundledModule>(m, "BundledModule").def(
-      "verify_result_with_bundled_expected_output",
-      &PyBundledModule::verify_result_with_bundled_expected_output,
-      py::arg("method_name"),
-      py::arg("testset_idx"),
-      py::arg("rtol") = 1e-5,
-      py::arg("atol") = 1e-8,
-      call_guard);
-
+  py::class_<PyBundledModule>(m, "BundledModule");
   py::class_<PyTensorInfo>(m, "TensorInfo")
       .def("sizes", &PyTensorInfo::sizes, call_guard)
       .def("dtype", &PyTensorInfo::dtype, call_guard)
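
For context, here is a minimal sketch of how the relocated bindings are driven from Python after this change. `load_bundled_input`, `plan_execute`, and `verify_result_with_bundled_expected_output` are the `ExecuTorchModule` bindings shown in the diff above; the two `_load_*` loader helpers and the file name are assumptions about the surrounding pybindings API, not part of this diff.

```python
# Sketch only: the two _load_* helpers are assumed to exist in this build of
# executorch.extension.pybindings.portable_lib; they are not defined by this
# diff, which only adds the ExecuTorchModule methods used below.
from executorch.extension.pybindings.portable_lib import (
    _load_bundled_program_from_buffer,  # assumed: bytes -> BundledModule
    _load_for_executorch_from_bundled_program,  # assumed: -> ExecuTorchModule
)

with open("model.bpte", "rb") as f:  # hypothetical bundled-program file
    bundle = _load_bundled_program_from_buffer(f.read())
module = _load_for_executorch_from_bundled_program(bundle)

# Path 1: stage testset 0 for "forward", then run the method explicitly.
module.load_bundled_input(bundle, "forward", 0)
outputs = module.plan_execute("forward")

# Path 2: one-shot helper that loads the inputs, runs the method, and checks
# the outputs against the bundled expected outputs (throws on mismatch).
outputs = module.verify_result_with_bundled_expected_output(
    bundle, "forward", 0, rtol=1e-5, atol=1e-8
)
```

Note the design choice visible in the diff: `PyBundledModule` no longer inherits from `BundledModule` and keeps only the serialized buffer, while execution and verification move onto `PyModule`, which already owns the loaded program; the `bundle` is passed in as an argument.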