Skip to content

Commit 2f07dad

Browse files
committed
[ExecuTorch][#10447] Extend PyBundledModule with extension.BundledModule
# Context This issue is a step of #9638. In #9638, we want to have `extension.Module` as the single source of implementation in `pybindings`, which means that `pybindings.PyModule` should use `extension.Module` rather than its own `pybindings.Module`. # Proposal Now that we have `extension.BundledModule` ready, we want to test it out by having our existing `PyBundledModule` to extend it, and let `verify_result_with_bundled_expected_output` to use it, so that we can test out the whole thing with https://github.com/pytorch/executorch/blob/fb45e19055a92d2a91a4d4b7008e135232cbb14b/devtools/bundled_program/test/test_end2end.py ghstack-source-id: 283524132 Differential Revision: [D73564127](https://our.internmc.facebook.com/intern/diff/D73564127/) [ghstack-poisoned]
1 parent cbd3874 commit 2f07dad

File tree

3 files changed

+83
-121
lines changed

3 files changed

+83
-121
lines changed

devtools/bundled_program/test/test_end2end.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# flake8: noqa: F401
8-
import functools
9-
import inspect
10-
import os
11-
import random
128
import unittest
13-
from typing import Callable, Dict, Optional, Tuple, Type
14-
15-
import executorch.exir as exir
16-
17-
import executorch.exir.control_flow as control_flow
18-
19-
# @manual=//executorch/extension/pytree:pybindings
20-
import executorch.extension.pytree as pytree
21-
22-
import torch
239

2410
from executorch.devtools.bundled_program.core import BundledProgram
2511
from executorch.devtools.bundled_program.serialize import (
@@ -35,8 +21,6 @@
3521
try:
3622
from executorch.extension.pybindings.portable_lib import (
3723
_load_bundled_program_from_buffer,
38-
_load_for_executorch_from_buffer,
39-
_load_for_executorch_from_bundled_program,
4024
)
4125

4226
kernel_mode = "lean"
@@ -47,8 +31,6 @@
4731
try:
4832
from executorch.extension.pybindings.aten_lib import ( # @manual=//executorch/extension/pybindings:aten_lib
4933
_load_bundled_program_from_buffer,
50-
_load_for_executorch_from_buffer,
51-
_load_for_executorch_from_bundled_program,
5234
)
5335

5436
assert kernel_mode is None
@@ -75,19 +57,8 @@ def test_sample_model_e2e(self):
7557
bundled_program_buffer
7658
)
7759

78-
executorch_module = _load_for_executorch_from_bundled_program(
79-
executorch_bundled_program
80-
)
81-
8260
for method_name in eager_model.method_names:
83-
executorch_module.load_bundled_input(
84-
executorch_bundled_program,
85-
method_name,
86-
0,
87-
)
88-
executorch_module.plan_execute(method_name)
89-
executorch_module.verify_result_with_bundled_expected_output(
90-
executorch_bundled_program,
61+
executorch_bundled_program.verify_result_with_bundled_expected_output(
9162
method_name,
9263
0,
9364
)

extension/pybindings/pybindings.cpp

Lines changed: 80 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <executorch/extension/data_loader/buffer_data_loader.h>
2424
#include <executorch/extension/data_loader/mmap_data_loader.h>
2525
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
26+
#include <executorch/extension/module/bundled_module.h>
2627
#include <executorch/extension/threadpool/threadpool.h>
2728
#include <executorch/runtime/backend/interface.h>
2829
#include <executorch/runtime/core/data_loader.h>
@@ -440,13 +441,54 @@ inline std::unique_ptr<Module> load_module_from_file(
440441
program_verification);
441442
}
442443

444+
inline py::list get_outputs_as_py_list(
445+
const std::vector<EValue>& outputs,
446+
bool clone_outputs = true) {
447+
const auto outputs_size = outputs.size();
448+
py::list list(outputs_size);
449+
for (size_t i = 0; i < outputs_size; ++i) {
450+
auto& v = outputs[i];
451+
if (Tag::None == v.tag) {
452+
list[i] = py::none();
453+
} else if (Tag::Int == v.tag) {
454+
list[i] = py::cast(v.toInt());
455+
} else if (Tag::Double == v.tag) {
456+
list[i] = py::cast(v.toDouble());
457+
} else if (Tag::Bool == v.tag) {
458+
list[i] = py::cast(v.toBool());
459+
} else if (Tag::String == v.tag) {
460+
list[i] = py::cast(std::string(v.toString().data()));
461+
} else if (Tag::Tensor == v.tag) {
462+
#ifdef USE_ATEN_LIB
463+
// Clone so the outputs in python do not share a lifetime with the
464+
// module object
465+
if (clone_outputs) {
466+
list[i] = py::cast(v.toTensor().clone());
467+
} else {
468+
list[i] = py::cast(v.toTensor());
469+
}
470+
#else
471+
if (clone_outputs) {
472+
list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
473+
} else {
474+
list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
475+
}
476+
#endif
477+
} else {
478+
ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
479+
}
480+
}
481+
return list;
482+
}
483+
443484
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
444485

445-
struct PyBundledModule final {
486+
struct PyBundledModule : public BundledModule {
446487
explicit PyBundledModule(
447488
const py::bytes& buffer,
448489
uint32_t bundled_input_pool_size)
449-
: bundled_program_ptr_(buffer),
490+
: BundledModule(buffer.cast<std::string_view>().data()),
491+
bundled_program_ptr_(buffer),
450492
program_ptr_(static_cast<const void*>(
451493
bundled_program_flatbuffer::GetBundledProgram(
452494
get_bundled_program_ptr())
@@ -475,6 +517,32 @@ struct PyBundledModule final {
475517
return program_len_;
476518
}
477519

520+
py::list verify_result_with_bundled_expected_output(
521+
const std::string& method_name,
522+
size_t testset_idx,
523+
double rtol = 1e-5,
524+
double atol = 1e-8) {
525+
// Execute the method
526+
auto result = BundledModule::execute(method_name, testset_idx);
527+
if (!result.ok()) {
528+
THROW_IF_ERROR(
529+
result.error(),
530+
"Method execution failed with status 0x%" PRIx32,
531+
static_cast<uint32_t>(result.error()));
532+
}
533+
534+
// Convert outputs to py::list
535+
const auto& outputs = result.get();
536+
py::list py_outputs = get_outputs_as_py_list(outputs);
537+
538+
Error status = BundledModule::verify_method_outputs(method_name, testset_idx, rtol, atol);
539+
THROW_IF_ERROR(
540+
status,
541+
"Result verification failed with status %" PRIu32,
542+
static_cast<uint32_t>(status));
543+
return py_outputs;
544+
}
545+
478546
private:
479547
// Store the bytes object instead of a raw pointer so that this module will
480548
// keep the bytes alive.
@@ -791,7 +859,7 @@ struct PyModule final {
791859
}
792860

793861
py::list forward_single_input(
794-
const torch::Tensor& inputTensor,
862+
const torch::Tensor& inputTensor,
795863
bool clone_outputs = true) {
796864
py::list py_list;
797865
py_list.append(py::cast(inputTensor));
@@ -831,43 +899,6 @@ struct PyModule final {
831899
}
832900
}
833901

834-
void load_bundled_input(
835-
PyBundledModule& m,
836-
const std::string method_name,
837-
size_t testset_idx) {
838-
const void* bundled_program_ptr = m.get_bundled_program_ptr();
839-
Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
840-
module_->get_method(method_name), bundled_program_ptr, testset_idx);
841-
THROW_IF_ERROR(
842-
status,
843-
"load_bundled_input failed with status 0x%" PRIx32,
844-
static_cast<uint32_t>(status));
845-
}
846-
847-
py::list verify_result_with_bundled_expected_output(
848-
PyBundledModule& m,
849-
const std::string method_name,
850-
size_t testset_idx,
851-
double rtol = 1e-5,
852-
double atol = 1e-8) {
853-
const void* bundled_program_ptr = m.get_bundled_program_ptr();
854-
auto& method = module_->get_method(method_name);
855-
Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
856-
method, bundled_program_ptr, testset_idx);
857-
THROW_IF_ERROR(
858-
status,
859-
"load_bundled_input failed with status 0x%" PRIx32,
860-
static_cast<uint32_t>(status));
861-
py::list outputs = plan_execute(method_name);
862-
status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs(
863-
method, bundled_program_ptr, testset_idx, rtol, atol);
864-
THROW_IF_ERROR(
865-
status,
866-
"Result verification failed with status %" PRIu32,
867-
static_cast<uint32_t>(status));
868-
return outputs;
869-
}
870-
871902
py::list plan_execute(
872903
const std::string method_name,
873904
bool clone_outputs = true) {
@@ -890,46 +921,6 @@ struct PyModule final {
890921
return get_outputs_as_py_list(outputs, clone_outputs);
891922
}
892923

893-
py::list get_outputs_as_py_list(
894-
const std::vector<EValue>& outputs,
895-
bool clone_outputs = true) {
896-
const auto outputs_size = outputs.size();
897-
py::list list(outputs_size);
898-
for (size_t i = 0; i < outputs_size; ++i) {
899-
auto& v = outputs[i];
900-
if (Tag::None == v.tag) {
901-
list[i] = py::none();
902-
} else if (Tag::Int == v.tag) {
903-
list[i] = py::cast(v.toInt());
904-
} else if (Tag::Double == v.tag) {
905-
list[i] = py::cast(v.toDouble());
906-
} else if (Tag::Bool == v.tag) {
907-
list[i] = py::cast(v.toBool());
908-
} else if (Tag::String == v.tag) {
909-
list[i] = py::cast(std::string(v.toString().data()));
910-
} else if (Tag::Tensor == v.tag) {
911-
#ifdef USE_ATEN_LIB
912-
// Clone so the outputs in python do not share a lifetime with the
913-
// module object
914-
if (clone_outputs) {
915-
list[i] = py::cast(v.toTensor().clone());
916-
} else {
917-
list[i] = py::cast(v.toTensor());
918-
}
919-
#else
920-
if (clone_outputs) {
921-
list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
922-
} else {
923-
list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
924-
}
925-
#endif
926-
} else {
927-
ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
928-
}
929-
}
930-
return list;
931-
}
932-
933924
std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
934925
auto& method = module_->get_method(method_name);
935926
return std::make_unique<PyMethodMeta>(module_, method.method_meta());
@@ -1089,16 +1080,6 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10891080
call_guard);
10901081

10911082
py::class_<PyModule>(m, "ExecuTorchModule")
1092-
.def("load_bundled_input", &PyModule::load_bundled_input, call_guard)
1093-
.def(
1094-
"verify_result_with_bundled_expected_output",
1095-
&PyModule::verify_result_with_bundled_expected_output,
1096-
py::arg("bundle"),
1097-
py::arg("method_name"),
1098-
py::arg("testset_idx"),
1099-
py::arg("rtol") = 1e-5,
1100-
py::arg("atol") = 1e-8,
1101-
call_guard)
11021083
.def(
11031084
"plan_execute",
11041085
&PyModule::plan_execute,
@@ -1144,7 +1125,15 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
11441125
py::arg("clone_outputs") = true,
11451126
call_guard);
11461127

1147-
py::class_<PyBundledModule>(m, "BundledModule");
1128+
py::class_<PyBundledModule>(m, "BundledModule").def(
1129+
"verify_result_with_bundled_expected_output",
1130+
&PyBundledModule::verify_result_with_bundled_expected_output,
1131+
py::arg("method_name"),
1132+
py::arg("testset_idx"),
1133+
py::arg("rtol") = 1e-5,
1134+
py::arg("atol") = 1e-8,
1135+
call_guard);
1136+
11481137
py::class_<PyTensorInfo>(m, "TensorInfo")
11491138
.def("sizes", &PyTensorInfo::sizes, call_guard)
11501139
.def("dtype", &PyTensorInfo::dtype, call_guard)

shim_et/xplat/executorch/extension/pybindings/pybindings.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ PORTABLE_MODULE_DEPS = [
1616
"//executorch/extension/data_loader:buffer_data_loader",
1717
"//executorch/extension/data_loader:mmap_data_loader",
1818
"//executorch/extension/memory_allocator:malloc_memory_allocator",
19+
"//executorch/extension/module:bundled_module",
1920
"//executorch/runtime/executor/test:test_backend_compiler_lib",
2021
"//executorch/devtools/etdump:etdump_flatcc",
2122
] + get_all_cpu_backend_targets()
@@ -28,6 +29,7 @@ ATEN_MODULE_DEPS = [
2829
"//executorch/extension/data_loader:buffer_data_loader",
2930
"//executorch/extension/data_loader:mmap_data_loader",
3031
"//executorch/extension/memory_allocator:malloc_memory_allocator",
32+
"//executorch/extension/module:bundled_module_aten",
3133
"//executorch/devtools/bundled_program:runtime_aten",
3234
"//executorch/runtime/executor/test:test_backend_compiler_lib_aten",
3335
"//executorch/devtools/etdump:etdump_flatcc",

0 commit comments

Comments
 (0)