 #include <executorch/extension/data_loader/buffer_data_loader.h>
 #include <executorch/extension/data_loader/mmap_data_loader.h>
 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
+#include <executorch/extension/module/bundled_module.h>
 #include <executorch/extension/threadpool/threadpool.h>
 #include <executorch/runtime/backend/interface.h>
 #include <executorch/runtime/core/data_loader.h>
@@ -81,6 +82,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Program;
 using ::executorch::extension::BufferDataLoader;
 using ::executorch::extension::MallocMemoryAllocator;
 using ::executorch::extension::MmapDataLoader;
+using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
 using ::executorch::runtime::ArrayRef;
 using ::executorch::runtime::DataLoader;
 using ::executorch::runtime::Error;
@@ -425,13 +427,54 @@ inline std::unique_ptr<Module> load_module_from_file(
       program_verification);
 }
 
+inline py::list get_outputs_as_py_list(
+    const std::vector<EValue>& outputs,
+    bool clone_outputs = true) {
+  const auto outputs_size = outputs.size();
+  py::list list(outputs_size);
+  for (size_t i = 0; i < outputs_size; ++i) {
+    auto& v = outputs[i];
+    if (Tag::None == v.tag) {
+      list[i] = py::none();
+    } else if (Tag::Int == v.tag) {
+      list[i] = py::cast(v.toInt());
+    } else if (Tag::Double == v.tag) {
+      list[i] = py::cast(v.toDouble());
+    } else if (Tag::Bool == v.tag) {
+      list[i] = py::cast(v.toBool());
+    } else if (Tag::String == v.tag) {
+      list[i] = py::cast(std::string(v.toString().data()));
+    } else if (Tag::Tensor == v.tag) {
+#ifdef USE_ATEN_LIB
+      // Clone so the outputs in python do not share a lifetime with the
+      // module object
+      if (clone_outputs) {
+        list[i] = py::cast(v.toTensor().clone());
+      } else {
+        list[i] = py::cast(v.toTensor());
+      }
+#else
+      if (clone_outputs) {
+        list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
+      } else {
+        list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
+      }
+#endif
+    } else {
+      ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
+    }
+  }
+  return list;
+}
+
 static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
 
-struct PyBundledModule final {
+struct PyBundledModule : public BundledModule {
   explicit PyBundledModule(
       const py::bytes& buffer,
       uint32_t bundled_input_pool_size)
-      : bundled_program_ptr_(buffer),
+      : BundledModule(buffer.cast<std::string_view>().data()),
+        bundled_program_ptr_(buffer),
         program_ptr_(static_cast<const void*>(
             bundled_program_flatbuffer::GetBundledProgram(
                 get_bundled_program_ptr())
@@ -460,6 +503,33 @@ struct PyBundledModule final {
     return program_len_;
   }
 
+  py::list verify_result_with_bundled_expected_output(
+      const std::string& method_name,
+      size_t testset_idx,
+      double rtol = 1e-5,
+      double atol = 1e-8) {
+    // Execute the method
+    auto result = BundledModule::execute(method_name, testset_idx);
+    if (!result.ok()) {
+      THROW_IF_ERROR(
+          result.error(),
+          "Method execution failed with status 0x%" PRIx32,
+          static_cast<uint32_t>(result.error()));
+    }
+
+    // Convert outputs to py::list
+    const auto& outputs = result.get();
+    py::list py_outputs = get_outputs_as_py_list(outputs);
+
+    Error status = BundledModule::verify_method_outputs(
+        method_name, testset_idx, rtol, atol);
+    THROW_IF_ERROR(
+        status,
+        "Result verification failed with status %" PRIu32,
+        static_cast<uint32_t>(status));
+    return py_outputs;
+  }
+
  private:
   // Store the bytes object instead of a raw pointer so that this module will
   // keep the bytes alive.
@@ -853,43 +923,6 @@ struct PyModule final {
     }
   }
 
-  void load_bundled_input(
-      PyBundledModule& m,
-      const std::string method_name,
-      size_t testset_idx) {
-    const void* bundled_program_ptr = m.get_bundled_program_ptr();
-    Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
-        module_->get_method(method_name), bundled_program_ptr, testset_idx);
-    THROW_IF_ERROR(
-        status,
-        "load_bundled_input failed with status 0x%" PRIx32,
-        static_cast<uint32_t>(status));
-  }
-
-  py::list verify_result_with_bundled_expected_output(
-      PyBundledModule& m,
-      const std::string method_name,
-      size_t testset_idx,
-      double rtol = 1e-5,
-      double atol = 1e-8) {
-    const void* bundled_program_ptr = m.get_bundled_program_ptr();
-    auto& method = module_->get_method(method_name);
-    Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
-        method, bundled_program_ptr, testset_idx);
-    THROW_IF_ERROR(
-        status,
-        "load_bundled_input failed with status 0x%" PRIx32,
-        static_cast<uint32_t>(status));
-    py::list outputs = plan_execute(method_name);
-    status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs(
-        method, bundled_program_ptr, testset_idx, rtol, atol);
-    THROW_IF_ERROR(
-        status,
-        "Result verification failed with status %" PRIu32,
-        static_cast<uint32_t>(status));
-    return outputs;
-  }
-
   py::list plan_execute(
       const std::string method_name,
       bool clone_outputs = true) {
@@ -912,46 +945,6 @@ struct PyModule final {
     return get_outputs_as_py_list(outputs, clone_outputs);
   }
 
-  py::list get_outputs_as_py_list(
-      const std::vector<EValue>& outputs,
-      bool clone_outputs = true) {
-    const auto outputs_size = outputs.size();
-    py::list list(outputs_size);
-    for (size_t i = 0; i < outputs_size; ++i) {
-      auto& v = outputs[i];
-      if (Tag::None == v.tag) {
-        list[i] = py::none();
-      } else if (Tag::Int == v.tag) {
-        list[i] = py::cast(v.toInt());
-      } else if (Tag::Double == v.tag) {
-        list[i] = py::cast(v.toDouble());
-      } else if (Tag::Bool == v.tag) {
-        list[i] = py::cast(v.toBool());
-      } else if (Tag::String == v.tag) {
-        list[i] = py::cast(std::string(v.toString().data()));
-      } else if (Tag::Tensor == v.tag) {
-#ifdef USE_ATEN_LIB
-        // Clone so the outputs in python do not share a lifetime with the
-        // module object
-        if (clone_outputs) {
-          list[i] = py::cast(v.toTensor().clone());
-        } else {
-          list[i] = py::cast(v.toTensor());
-        }
-#else
-        if (clone_outputs) {
-          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
-        } else {
-          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
-        }
-#endif
-      } else {
-        ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
-      }
-    }
-    return list;
-  }
-
   std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
     auto& method = module_->get_method(method_name);
     return std::make_unique<PyMethodMeta>(module_, method.method_meta());
@@ -1583,16 +1576,6 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
       call_guard);
 
   py::class_<PyModule>(m, "ExecuTorchModule")
-      .def("load_bundled_input", &PyModule::load_bundled_input, call_guard)
-      .def(
-          "verify_result_with_bundled_expected_output",
-          &PyModule::verify_result_with_bundled_expected_output,
-          py::arg("bundle"),
-          py::arg("method_name"),
-          py::arg("testset_idx"),
-          py::arg("rtol") = 1e-5,
-          py::arg("atol") = 1e-8,
-          call_guard)
       .def(
           "plan_execute",
           &PyModule::plan_execute,
@@ -1638,7 +1621,16 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
           py::arg("clone_outputs") = true,
           call_guard);
 
-  py::class_<PyBundledModule>(m, "BundledModule");
+  py::class_<PyBundledModule>(m, "BundledModule")
+      .def(
+          "verify_result_with_bundled_expected_output",
+          &PyBundledModule::verify_result_with_bundled_expected_output,
+          py::arg("method_name"),
+          py::arg("testset_idx"),
+          py::arg("rtol") = 1e-5,
+          py::arg("atol") = 1e-8,
+          call_guard);
+
   py::class_<PyTensorInfo>(m, "TensorInfo")
       .def("sizes", &PyTensorInfo::sizes, call_guard)
       .def("dtype", &PyTensorInfo::dtype, call_guard)
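
Not part of the diff: a minimal sketch of how the relocated binding might be exercised from Python after this change. It assumes the extension is importable as executorch.extension.pybindings.portable_lib and that an existing loader such as _load_bundled_program_from_buffer (not shown in this hunk) returns the BundledModule bound above; treat the module path, loader name, and file name as assumptions rather than guarantees.

# Hedged usage sketch. Assumptions: the portable_lib module path, the
# _load_bundled_program_from_buffer loader, and the .bpte file name are not
# part of this diff and may differ in a given build.
from executorch.extension.pybindings.portable_lib import (
    _load_bundled_program_from_buffer,
)

with open("model_bundled.bpte", "rb") as f:
    bundle_bytes = f.read()

# The loader is assumed to construct the PyBundledModule bound as "BundledModule".
bundle = _load_bundled_program_from_buffer(bundle_bytes)

# After this change the verification entry point lives on BundledModule itself:
# it executes the named method on the bundled test inputs, compares the outputs
# against the bundled expected outputs within rtol/atol, and returns the outputs.
outputs = bundle.verify_result_with_bundled_expected_output(
    "forward", 0, rtol=1e-5, atol=1e-8
)
print(outputs)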