 #include <executorch/extension/data_loader/buffer_data_loader.h>
 #include <executorch/extension/data_loader/mmap_data_loader.h>
 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
+#include <executorch/extension/module/bundled_module.h>
 #include <executorch/extension/threadpool/threadpool.h>
 #include <executorch/runtime/backend/interface.h>
 #include <executorch/runtime/core/data_loader.h>
@@ -81,6 +82,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Program;
 using ::executorch::extension::BufferDataLoader;
 using ::executorch::extension::MallocMemoryAllocator;
 using ::executorch::extension::MmapDataLoader;
+using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
 using ::executorch::runtime::ArrayRef;
 using ::executorch::runtime::DataLoader;
 using ::executorch::runtime::Error;
@@ -425,13 +427,54 @@ inline std::unique_ptr<Module> load_module_from_file(
       program_verification);
 }
 
+inline py::list get_outputs_as_py_list(
+    const std::vector<EValue>& outputs,
+    bool clone_outputs = true) {
+  const auto outputs_size = outputs.size();
+  py::list list(outputs_size);
+  for (size_t i = 0; i < outputs_size; ++i) {
+    auto& v = outputs[i];
+    if (Tag::None == v.tag) {
+      list[i] = py::none();
+    } else if (Tag::Int == v.tag) {
+      list[i] = py::cast(v.toInt());
+    } else if (Tag::Double == v.tag) {
+      list[i] = py::cast(v.toDouble());
+    } else if (Tag::Bool == v.tag) {
+      list[i] = py::cast(v.toBool());
+    } else if (Tag::String == v.tag) {
+      list[i] = py::cast(std::string(v.toString().data()));
+    } else if (Tag::Tensor == v.tag) {
+#ifdef USE_ATEN_LIB
+      // Clone so the outputs in python do not share a lifetime with the
+      // module object
+      if (clone_outputs) {
+        list[i] = py::cast(v.toTensor().clone());
+      } else {
+        list[i] = py::cast(v.toTensor());
+      }
+#else
+      if (clone_outputs) {
+        list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
+      } else {
+        list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
+      }
+#endif
+    } else {
+      ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
+    }
+  }
+  return list;
+}
+
 static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
 
-struct PyBundledModule final {
+struct PyBundledModule : public BundledModule {
   explicit PyBundledModule(
       const py::bytes& buffer,
       uint32_t bundled_input_pool_size)
-      : bundled_program_ptr_(buffer),
+      : BundledModule(buffer.cast<std::string_view>().data()),
+        bundled_program_ptr_(buffer),
         program_ptr_(static_cast<const void*>(
             bundled_program_flatbuffer::GetBundledProgram(
                 get_bundled_program_ptr())
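Note on the hoisted helper: get_outputs_as_py_list is a tag dispatch over EValue. Each output is converted to the matching Python object, and tensors are cloned by default so the returned Python values do not share a lifetime with the module that produced them. As a rough, self-contained illustration of the same dispatch pattern, here is a sketch only, with std::variant standing in for EValue's tagged union; SketchValue and to_py_list are hypothetical names, not ExecuTorch API.

// Sketch: std::variant plays the role of EValue's tagged union.
#include <pybind11/pybind11.h>

#include <cstddef>
#include <cstdint>
#include <string>
#include <type_traits>
#include <variant>
#include <vector>

namespace py = pybind11;

using SketchValue =
    std::variant<std::monostate, int64_t, double, bool, std::string>;

py::list to_py_list(const std::vector<SketchValue>& values) {
  py::list out(values.size());
  for (size_t i = 0; i < values.size(); ++i) {
    std::visit(
        [&](const auto& v) {
          using T = std::decay_t<decltype(v)>;
          if constexpr (std::is_same_v<T, std::monostate>) {
            out[i] = py::none();  // mirrors the Tag::None branch
          } else {
            out[i] = py::cast(v);  // int64_t/double/bool/std::string
          }
        },
        values[i]);
  }
  return out;
}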
@@ -460,6 +503,33 @@ struct PyBundledModule final {
     return program_len_;
   }
 
+  py::list verify_result_with_bundled_expected_output(
+      const std::string& method_name,
+      size_t testset_idx,
+      double rtol = 1e-5,
+      double atol = 1e-8) {
+    // Execute the method
+    auto result = BundledModule::execute(method_name, testset_idx);
+    if (!result.ok()) {
+      THROW_IF_ERROR(
+          result.error(),
+          "Method execution failed with status 0x%" PRIx32,
+          static_cast<uint32_t>(result.error()));
+    }
+
+    // Convert outputs to py::list
+    const auto& outputs = result.get();
+    py::list py_outputs = get_outputs_as_py_list(outputs);
+
+    Error status = BundledModule::verify_method_outputs(
+        method_name, testset_idx, rtol, atol);
+    THROW_IF_ERROR(
+        status,
+        "Result verification failed with status %" PRIu32,
+        static_cast<uint32_t>(status));
+    return py_outputs;
+  }
+
  private:
   // Store the bytes object instead of a raw pointer so that this module will
   // keep the bytes alive.
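The method above composes two BundledModule calls: execute(method_name, testset_idx) runs the method on the bundled test inputs and returns its outputs, and verify_method_outputs(method_name, testset_idx, rtol, atol) compares those outputs against the expected outputs stored in the bundled program. Below is a hedged sketch of the same flow driven directly from C++, assuming only the constructor form and calls visible in this diff; run_and_verify is a hypothetical helper and error handling is reduced to a bool.

// Sketch under the assumptions stated above, not a definitive API reference.
#include <executorch/extension/module/bundled_module.h>
#include <executorch/runtime/core/error.h>

using executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
using executorch::runtime::Error;

bool run_and_verify(const char* bundled_program_data) {
  // Same constructor form as PyBundledModule uses above.
  BundledModule module(bundled_program_data);
  auto result = module.execute("forward", /*testset_idx=*/0);
  if (!result.ok()) {
    return false;  // execution failed; result.get() holds outputs on success
  }
  Error status = module.verify_method_outputs(
      "forward", /*testset_idx=*/0, /*rtol=*/1e-5, /*atol=*/1e-8);
  return status == Error::Ok;
}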
@@ -853,43 +923,6 @@ struct PyModule final {
     }
   }
 
-  void load_bundled_input(
-      PyBundledModule& m,
-      const std::string method_name,
-      size_t testset_idx) {
-    const void* bundled_program_ptr = m.get_bundled_program_ptr();
-    Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
-        module_->get_method(method_name), bundled_program_ptr, testset_idx);
-    THROW_IF_ERROR(
-        status,
-        "load_bundled_input failed with status 0x%" PRIx32,
-        static_cast<uint32_t>(status));
-  }
-
-  py::list verify_result_with_bundled_expected_output(
-      PyBundledModule& m,
-      const std::string method_name,
-      size_t testset_idx,
-      double rtol = 1e-5,
-      double atol = 1e-8) {
-    const void* bundled_program_ptr = m.get_bundled_program_ptr();
-    auto& method = module_->get_method(method_name);
-    Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
-        method, bundled_program_ptr, testset_idx);
-    THROW_IF_ERROR(
-        status,
-        "load_bundled_input failed with status 0x%" PRIx32,
-        static_cast<uint32_t>(status));
-    py::list outputs = plan_execute(method_name);
-    status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs(
-        method, bundled_program_ptr, testset_idx, rtol, atol);
-    THROW_IF_ERROR(
-        status,
-        "Result verification failed with status %" PRIu32,
-        static_cast<uint32_t>(status));
-    return outputs;
-  }
-
   py::list plan_execute(
       const std::string method_name,
       bool clone_outputs = true) {
@@ -912,46 +945,6 @@ struct PyModule final {
     return get_outputs_as_py_list(outputs, clone_outputs);
   }
 
-  py::list get_outputs_as_py_list(
-      const std::vector<EValue>& outputs,
-      bool clone_outputs = true) {
-    const auto outputs_size = outputs.size();
-    py::list list(outputs_size);
-    for (size_t i = 0; i < outputs_size; ++i) {
-      auto& v = outputs[i];
-      if (Tag::None == v.tag) {
-        list[i] = py::none();
-      } else if (Tag::Int == v.tag) {
-        list[i] = py::cast(v.toInt());
-      } else if (Tag::Double == v.tag) {
-        list[i] = py::cast(v.toDouble());
-      } else if (Tag::Bool == v.tag) {
-        list[i] = py::cast(v.toBool());
-      } else if (Tag::String == v.tag) {
-        list[i] = py::cast(std::string(v.toString().data()));
-      } else if (Tag::Tensor == v.tag) {
-#ifdef USE_ATEN_LIB
-        // Clone so the outputs in python do not share a lifetime with the
-        // module object
-        if (clone_outputs) {
-          list[i] = py::cast(v.toTensor().clone());
-        } else {
-          list[i] = py::cast(v.toTensor());
-        }
-#else
-        if (clone_outputs) {
-          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
-        } else {
-          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
-        }
-#endif
-      } else {
-        ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
-      }
-    }
-    return list;
-  }
-
   std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
     auto& method = module_->get_method(method_name);
     return std::make_unique<PyMethodMeta>(module_, method.method_meta());
@@ -1583,16 +1576,6 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
       call_guard);
 
   py::class_<PyModule>(m, "ExecuTorchModule")
-      .def("load_bundled_input", &PyModule::load_bundled_input, call_guard)
-      .def(
-          "verify_result_with_bundled_expected_output",
-          &PyModule::verify_result_with_bundled_expected_output,
-          py::arg("bundle"),
-          py::arg("method_name"),
-          py::arg("testset_idx"),
-          py::arg("rtol") = 1e-5,
-          py::arg("atol") = 1e-8,
-          call_guard)
       .def(
           "plan_execute",
           &PyModule::plan_execute,
@@ -1638,7 +1621,16 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
           py::arg("clone_outputs") = true,
           call_guard);
 
-  py::class_<PyBundledModule>(m, "BundledModule");
+  py::class_<PyBundledModule>(m, "BundledModule")
+      .def(
+          "verify_result_with_bundled_expected_output",
+          &PyBundledModule::verify_result_with_bundled_expected_output,
+          py::arg("method_name"),
+          py::arg("testset_idx"),
+          py::arg("rtol") = 1e-5,
+          py::arg("atol") = 1e-8,
+          call_guard);
+
   py::class_<PyTensorInfo>(m, "TensorInfo")
       .def("sizes", &PyTensorInfo::sizes, call_guard)
       .def("dtype", &PyTensorInfo::dtype, call_guard)