diff --git a/devtools/bundled_program/test/test_end2end.py b/devtools/bundled_program/test/test_end2end.py index 7cee073be0e..42f38adabf5 100644 --- a/devtools/bundled_program/test/test_end2end.py +++ b/devtools/bundled_program/test/test_end2end.py @@ -5,21 +5,7 @@ # LICENSE file in the root directory of this source tree. # flake8: noqa: F401 -import functools -import inspect -import os -import random import unittest -from typing import Callable, Dict, Optional, Tuple, Type - -import executorch.exir as exir - -import executorch.exir.control_flow as control_flow - -# @manual=//executorch/extension/pytree:pybindings -import executorch.extension.pytree as pytree - -import torch from executorch.devtools.bundled_program.core import BundledProgram from executorch.devtools.bundled_program.serialize import ( @@ -35,7 +21,6 @@ try: from executorch.extension.pybindings.portable_lib import ( _load_bundled_program_from_buffer, - _load_for_executorch_from_buffer, _load_for_executorch_from_bundled_program, ) @@ -47,7 +32,6 @@ try: from executorch.extension.pybindings.aten_lib import ( # @manual=//executorch/extension/pybindings:aten_lib _load_bundled_program_from_buffer, - _load_for_executorch_from_buffer, _load_for_executorch_from_bundled_program, ) diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index 78fca25ed02..17ea6c7af7b 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -442,11 +443,12 @@ inline std::unique_ptr load_module_from_file( static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U; -struct PyBundledModule final { +struct PyBundledModule : public BundledModule { explicit PyBundledModule( const py::bytes& buffer, uint32_t bundled_input_pool_size) - : bundled_program_ptr_(buffer), + : BundledModule(buffer.cast().data()), + bundled_program_ptr_(buffer), program_ptr_(static_cast( bundled_program_flatbuffer::GetBundledProgram( get_bundled_program_ptr()) @@ -842,22 +844,20 @@ struct PyModule final { size_t testset_idx, double rtol = 1e-5, double atol = 1e-8) { - const void* bundled_program_ptr = m.get_bundled_program_ptr(); - auto& method = module_->get_method(method_name); - Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input( - method, bundled_program_ptr, testset_idx); + auto outputs = m.execute(method_name, testset_idx); + THROW_IF_ERROR( - status, - "load_bundled_input failed with status 0x%" PRIx32, - static_cast(status)); - py::list outputs = plan_execute(method_name); - status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs( - method, bundled_program_ptr, testset_idx, rtol, atol); + outputs.error(), + "Execution failed with status 0x%" PRIx32, + static_cast(outputs.error())); + + auto status = m.verify_method_outputs(method_name, testset_idx, rtol, atol); THROW_IF_ERROR( status, "Result verification failed with status %" PRIu32, static_cast(status)); - return outputs; + + return get_outputs_as_py_list(outputs.get()); } py::list plan_execute( diff --git a/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl b/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl index 1616304c3ea..55a268d5d34 100644 --- a/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl +++ b/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl @@ -16,6 +16,7 @@ PORTABLE_MODULE_DEPS = [ "//executorch/extension/data_loader:buffer_data_loader", "//executorch/extension/data_loader:mmap_data_loader", "//executorch/extension/memory_allocator:malloc_memory_allocator", + "//executorch/extension/module:bundled_module", "//executorch/runtime/executor/test:test_backend_compiler_lib", "//executorch/devtools/etdump:etdump_flatcc", ] + get_all_cpu_backend_targets() @@ -28,6 +29,7 @@ ATEN_MODULE_DEPS = [ "//executorch/extension/data_loader:buffer_data_loader", "//executorch/extension/data_loader:mmap_data_loader", "//executorch/extension/memory_allocator:malloc_memory_allocator", + "//executorch/extension/module:bundled_module_aten", "//executorch/devtools/bundled_program:runtime_aten", "//executorch/runtime/executor/test:test_backend_compiler_lib_aten", "//executorch/devtools/etdump:etdump_flatcc",