2323#include < executorch/extension/data_loader/buffer_data_loader.h>
2424#include < executorch/extension/data_loader/mmap_data_loader.h>
2525#include < executorch/extension/memory_allocator/malloc_memory_allocator.h>
26+ #include < executorch/extension/module/bundled_module.h>
2627#include < executorch/extension/threadpool/threadpool.h>
2728#include < executorch/runtime/backend/interface.h>
2829#include < executorch/runtime/core/data_loader.h>
@@ -440,13 +441,54 @@ inline std::unique_ptr<Module> load_module_from_file(
440441 program_verification);
441442}
442443
444+ inline py::list get_outputs_as_py_list (
445+ const std::vector<EValue>& outputs,
446+ bool clone_outputs = true ) {
447+ const auto outputs_size = outputs.size ();
448+ py::list list (outputs_size);
449+ for (size_t i = 0 ; i < outputs_size; ++i) {
450+ auto & v = outputs[i];
451+ if (Tag::None == v.tag ) {
452+ list[i] = py::none ();
453+ } else if (Tag::Int == v.tag ) {
454+ list[i] = py::cast (v.toInt ());
455+ } else if (Tag::Double == v.tag ) {
456+ list[i] = py::cast (v.toDouble ());
457+ } else if (Tag::Bool == v.tag ) {
458+ list[i] = py::cast (v.toBool ());
459+ } else if (Tag::String == v.tag ) {
460+ list[i] = py::cast (std::string (v.toString ().data ()));
461+ } else if (Tag::Tensor == v.tag ) {
462+ #ifdef USE_ATEN_LIB
463+ // Clone so the outputs in python do not share a lifetime with the
464+ // module object
465+ if (clone_outputs) {
466+ list[i] = py::cast (v.toTensor ().clone ());
467+ } else {
468+ list[i] = py::cast (v.toTensor ());
469+ }
470+ #else
471+ if (clone_outputs) {
472+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
473+ } else {
474+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
475+ }
476+ #endif
477+ } else {
478+ ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
479+ }
480+ }
481+ return list;
482+ }
483+
443484static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U ;
444485
445- struct PyBundledModule final {
486+ struct PyBundledModule : public BundledModule {
446487 explicit PyBundledModule (
447488 const py::bytes& buffer,
448489 uint32_t bundled_input_pool_size)
449- : bundled_program_ptr_(buffer),
490+ : BundledModule(buffer.cast<std::string_view>().data()),
491+ bundled_program_ptr_(buffer),
450492 program_ptr_(static_cast <const void *>(
451493 bundled_program_flatbuffer::GetBundledProgram (
452494 get_bundled_program_ptr ())
@@ -475,6 +517,32 @@ struct PyBundledModule final {
475517 return program_len_;
476518 }
477519
  /// Runs `method_name` on the bundled test input at `testset_idx`, then
  /// verifies the produced outputs against the expected outputs stored in
  /// the bundled program.
  ///
  /// @param method_name Name of the method to execute.
  /// @param testset_idx Index of the bundled input/expected-output test set.
  /// @param rtol Relative tolerance for output comparison.
  /// @param atol Absolute tolerance for output comparison.
  /// @return The method's outputs converted to a Python list.
  /// Throws (via THROW_IF_ERROR) if execution or verification fails.
  py::list verify_result_with_bundled_expected_output(
      const std::string& method_name,
      size_t testset_idx,
      double rtol = 1e-5,
      double atol = 1e-8) {
    // Execute the method with the bundled input selected by testset_idx.
    auto result = BundledModule::execute(method_name, testset_idx);
    if (!result.ok()) {
      THROW_IF_ERROR(
          result.error(),
          "Method execution failed with status 0x%" PRIx32,
          static_cast<uint32_t>(result.error()));
    }

    // Convert outputs to py::list before verification so the caller still
    // receives the raw outputs alongside a successful check.
    const auto& outputs = result.get();
    py::list py_outputs = get_outputs_as_py_list(outputs);

    // Compare the just-produced outputs against the bundle's expected
    // outputs within the given tolerances; throws on mismatch.
    Error status =
        BundledModule::verify_method_outputs(method_name, testset_idx, rtol, atol);
    THROW_IF_ERROR(
        status,
        "Result verification failed with status %" PRIu32,
        static_cast<uint32_t>(status));
    return py_outputs;
  }
545+
478546 private:
479547 // Store the bytes object instead of a raw pointer so that this module will
480548 // keep the bytes alive.
@@ -791,7 +859,7 @@ struct PyModule final {
791859 }
792860
793861 py::list forward_single_input (
794- const torch::Tensor& inputTensor,
862+ const torch::Tensor& inputTensor,
795863 bool clone_outputs = true ) {
796864 py::list py_list;
797865 py_list.append (py::cast (inputTensor));
@@ -831,43 +899,6 @@ struct PyModule final {
831899 }
832900 }
833901
834- void load_bundled_input (
835- PyBundledModule& m,
836- const std::string method_name,
837- size_t testset_idx) {
838- const void * bundled_program_ptr = m.get_bundled_program_ptr ();
839- Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
840- module_->get_method (method_name), bundled_program_ptr, testset_idx);
841- THROW_IF_ERROR (
842- status,
843- " load_bundled_input failed with status 0x%" PRIx32,
844- static_cast <uint32_t >(status));
845- }
846-
847- py::list verify_result_with_bundled_expected_output (
848- PyBundledModule& m,
849- const std::string method_name,
850- size_t testset_idx,
851- double rtol = 1e-5 ,
852- double atol = 1e-8 ) {
853- const void * bundled_program_ptr = m.get_bundled_program_ptr ();
854- auto & method = module_->get_method (method_name);
855- Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
856- method, bundled_program_ptr, testset_idx);
857- THROW_IF_ERROR (
858- status,
859- " load_bundled_input failed with status 0x%" PRIx32,
860- static_cast <uint32_t >(status));
861- py::list outputs = plan_execute (method_name);
862- status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs (
863- method, bundled_program_ptr, testset_idx, rtol, atol);
864- THROW_IF_ERROR (
865- status,
866- " Result verification failed with status %" PRIu32,
867- static_cast <uint32_t >(status));
868- return outputs;
869- }
870-
871902 py::list plan_execute (
872903 const std::string method_name,
873904 bool clone_outputs = true ) {
@@ -890,46 +921,6 @@ struct PyModule final {
890921 return get_outputs_as_py_list (outputs, clone_outputs);
891922 }
892923
893- py::list get_outputs_as_py_list (
894- const std::vector<EValue>& outputs,
895- bool clone_outputs = true ) {
896- const auto outputs_size = outputs.size ();
897- py::list list (outputs_size);
898- for (size_t i = 0 ; i < outputs_size; ++i) {
899- auto & v = outputs[i];
900- if (Tag::None == v.tag ) {
901- list[i] = py::none ();
902- } else if (Tag::Int == v.tag ) {
903- list[i] = py::cast (v.toInt ());
904- } else if (Tag::Double == v.tag ) {
905- list[i] = py::cast (v.toDouble ());
906- } else if (Tag::Bool == v.tag ) {
907- list[i] = py::cast (v.toBool ());
908- } else if (Tag::String == v.tag ) {
909- list[i] = py::cast (std::string (v.toString ().data ()));
910- } else if (Tag::Tensor == v.tag ) {
911- #ifdef USE_ATEN_LIB
912- // Clone so the outputs in python do not share a lifetime with the
913- // module object
914- if (clone_outputs) {
915- list[i] = py::cast (v.toTensor ().clone ());
916- } else {
917- list[i] = py::cast (v.toTensor ());
918- }
919- #else
920- if (clone_outputs) {
921- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
922- } else {
923- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
924- }
925- #endif
926- } else {
927- ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
928- }
929- }
930- return list;
931- }
932-
933924 std::unique_ptr<PyMethodMeta> method_meta (const std::string method_name) {
934925 auto & method = module_->get_method (method_name);
935926 return std::make_unique<PyMethodMeta>(module_, method.method_meta ());
@@ -1089,16 +1080,6 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10891080 call_guard);
10901081
10911082 py::class_<PyModule>(m, " ExecuTorchModule" )
1092- .def (" load_bundled_input" , &PyModule::load_bundled_input, call_guard)
1093- .def (
1094- " verify_result_with_bundled_expected_output" ,
1095- &PyModule::verify_result_with_bundled_expected_output,
1096- py::arg (" bundle" ),
1097- py::arg (" method_name" ),
1098- py::arg (" testset_idx" ),
1099- py::arg (" rtol" ) = 1e-5 ,
1100- py::arg (" atol" ) = 1e-8 ,
1101- call_guard)
11021083 .def (
11031084 " plan_execute" ,
11041085 &PyModule::plan_execute,
@@ -1144,7 +1125,15 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
11441125 py::arg (" clone_outputs" ) = true ,
11451126 call_guard);
11461127
  // Python binding for PyBundledModule: exposes execution of a bundled test
  // set plus verification of its outputs against the bundle's expected
  // outputs, with configurable comparison tolerances.
  py::class_<PyBundledModule>(m, "BundledModule")
      .def(
          "verify_result_with_bundled_expected_output",
          &PyBundledModule::verify_result_with_bundled_expected_output,
          py::arg("method_name"),
          py::arg("testset_idx"),
          py::arg("rtol") = 1e-5,
          py::arg("atol") = 1e-8,
          call_guard);

11481137 py::class_<PyTensorInfo>(m, " TensorInfo" )
11491138 .def (" sizes" , &PyTensorInfo::sizes, call_guard)
11501139 .def (" dtype" , &PyTensorInfo::dtype, call_guard)
0 commit comments