Skip to content

Commit 7a1297a

Browse files
committed
Add mapping from C++ program::verification to Python (#5915)
Summary: As titled. This enables `portable_lib._load_for_executorch[_from_buffer]` to accept `Program::Verification` argument. See added test, now we can do something like: ``` from executorch.extension.pybindings.portable_lib import Verification module = load_fn( exported_program.buffer, enable_etdump=False, debug_buffer_size=0, program_verification=Verification.Minimal, ) ``` Pull Request resolved: #5915 Test Plan: See unit test Reviewed By: dbort Differential Revision: D63987538 Pulled By: larryliu0820 fbshipit-source-id: b68d8d1149e2d46b90544679707f420179e72b19
1 parent 859928d commit 7a1297a

File tree

5 files changed

+120
-35
lines changed

5 files changed

+120
-35
lines changed

extension/pybindings/portable_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_reset_profile_results, # noqa: F401
4646
BundledModule, # noqa: F401
4747
ExecuTorchModule, # noqa: F401
48+
Verification, # noqa: F401
4849
)
4950

5051
# Clean up so that `dir(portable_lib)` is the same as `dir(_portable_lib)`

extension/pybindings/pybindings.cpp

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,15 @@ class Module final {
168168
explicit Module(
169169
std::unique_ptr<DataLoader> loader,
170170
std::unique_ptr<ETDumpGen> tracer = nullptr,
171-
size_t debug_buffer_size = 0)
171+
size_t debug_buffer_size = 0,
172+
Program::Verification program_verification =
173+
Program::Verification::InternalConsistency)
172174
: loader_(std::move(loader)),
173175
event_tracer_(std::move(tracer)),
174176
debug_buffer_size_(debug_buffer_size) {
175177
::executorch::runtime::runtime_init();
176-
Result<Program> program = Program::load(
177-
loader_.get(), Program::Verification::InternalConsistency);
178+
Result<Program> program =
179+
Program::load(loader_.get(), program_verification);
178180
THROW_IF_ERROR(
179181
program.error(),
180182
"loading program failed with error: 0x%" PRIx32,
@@ -386,19 +388,22 @@ inline std::unique_ptr<Module> load_module_from_buffer(
386388
const void* ptr,
387389
size_t ptr_len,
388390
bool enable_etdump,
389-
size_t debug_buffer_size) {
391+
size_t debug_buffer_size,
392+
Program::Verification program_verification) {
390393
EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
391394
auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);
392395
return std::make_unique<Module>(
393396
std::move(loader),
394397
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
395-
debug_buffer_size);
398+
debug_buffer_size,
399+
program_verification);
396400
}
397401

398402
inline std::unique_ptr<Module> load_module_from_file(
399403
const std::string& path,
400404
bool enable_etdump,
401-
size_t debug_buffer_size) {
405+
size_t debug_buffer_size,
406+
Program::Verification program_verification) {
402407
EXECUTORCH_SCOPE_PROF("load_module_from_file");
403408

404409
Result<MmapDataLoader> res = MmapDataLoader::from(
@@ -413,7 +418,8 @@ inline std::unique_ptr<Module> load_module_from_file(
413418
return std::make_unique<Module>(
414419
std::move(loader),
415420
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
416-
debug_buffer_size);
421+
debug_buffer_size,
422+
program_verification);
417423
}
418424

419425
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
@@ -576,30 +582,41 @@ struct PyModule final {
576582
explicit PyModule(
577583
const py::bytes& buffer,
578584
bool enable_etdump,
579-
size_t debug_buffer_size = 0)
585+
size_t debug_buffer_size = 0,
586+
Program::Verification program_verification =
587+
Program::Verification::InternalConsistency)
580588
: module_(load_module_from_buffer(
581589
buffer.cast<std::string_view>().data(),
582590
py::len(buffer),
583591
enable_etdump,
584-
debug_buffer_size)) {}
592+
debug_buffer_size,
593+
program_verification)) {}
585594

586595
explicit PyModule(
587596
const void* ptr,
588597
size_t ptr_len,
589598
bool enable_etdump,
590-
size_t debug_buffer_size = 0)
599+
size_t debug_buffer_size = 0,
600+
Program::Verification program_verification =
601+
Program::Verification::InternalConsistency)
591602
: module_(load_module_from_buffer(
592603
ptr,
593604
ptr_len,
594605
enable_etdump,
595-
debug_buffer_size)) {}
606+
debug_buffer_size,
607+
program_verification)) {}
596608

597609
explicit PyModule(
598610
const std::string& path,
599611
bool enable_etdump,
600-
size_t debug_buffer_size = 0)
601-
: module_(load_module_from_file(path, enable_etdump, debug_buffer_size)) {
602-
}
612+
size_t debug_buffer_size = 0,
613+
Program::Verification program_verification =
614+
Program::Verification::InternalConsistency)
615+
: module_(load_module_from_file(
616+
path,
617+
enable_etdump,
618+
debug_buffer_size,
619+
program_verification)) {}
603620

604621
PyModule(const PyModule&) = delete;
605622
PyModule& operator=(const PyModule&) = delete;
@@ -610,14 +627,20 @@ struct PyModule final {
610627
static std::unique_ptr<PyModule> load_from_buffer(
611628
const py::bytes& buffer,
612629
bool enable_etdump,
613-
size_t debug_buffer_size = 0) {
614-
return std::make_unique<PyModule>(buffer, enable_etdump, debug_buffer_size);
630+
size_t debug_buffer_size = 0,
631+
Program::Verification program_verification =
632+
Program::Verification::InternalConsistency) {
633+
return std::make_unique<PyModule>(
634+
buffer, enable_etdump, debug_buffer_size, program_verification);
615635
}
616636
static std::unique_ptr<PyModule> load_from_file(
617637
const std::string& path,
618638
bool enable_etdump,
619-
size_t debug_buffer_size = 0) {
620-
return std::make_unique<PyModule>(path, enable_etdump, debug_buffer_size);
639+
size_t debug_buffer_size = 0,
640+
Program::Verification program_verification =
641+
Program::Verification::InternalConsistency) {
642+
return std::make_unique<PyModule>(
643+
path, enable_etdump, debug_buffer_size, program_verification);
621644
}
622645

623646
static std::unique_ptr<PyModule> load_from_bundled_program(
@@ -934,19 +957,29 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
934957
// Redirects cout and cerr for function calls this guards to the python env.
935958
auto call_guard = py::
936959
call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>();
960+
961+
// Bind the verification enum to python.
962+
py::enum_<Program::Verification>(m, "Verification")
963+
.value("Minimal", Program::Verification::Minimal)
964+
.value("InternalConsistency", Program::Verification::InternalConsistency);
965+
937966
m.def(
938967
"_load_for_executorch",
939968
PyModule::load_from_file,
940969
py::arg("path"),
941970
py::arg("enable_etdump") = false,
942971
py::arg("debug_buffer_size") = 0,
972+
py::arg("program_verification") =
973+
Program::Verification::InternalConsistency,
943974
call_guard);
944975
m.def(
945976
"_load_for_executorch_from_buffer",
946977
&PyModule::load_from_buffer,
947978
py::arg("buffer"),
948979
py::arg("enable_etdump") = false,
949980
py::arg("debug_buffer_size") = 0,
981+
py::arg("program_verification") =
982+
Program::Verification::InternalConsistency,
950983
call_guard);
951984
m.def(
952985
"_load_for_executorch_from_bundled_program",

extension/pybindings/pybindings.pyi

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,22 @@
77
# pyre-strict
88
from __future__ import annotations
99

10-
from typing import Any, Dict, List, Optional, Sequence, Tuple
10+
from typing import Any, Dict, Enum, List, Optional, Sequence, Tuple
1111

1212
from executorch.exir._warnings import experimental
1313

14+
@experimental("This API is experimental and subject to change without notice.")
15+
class Verification(Enum):
16+
"""Verification maps C++ Program::Verification to Python.
17+
18+
.. warning::
19+
20+
This API is experimental and subject to change without notice.
21+
"""
22+
23+
Minimal: ...
24+
InternalConsistency: ...
25+
1426
@experimental("This API is experimental and subject to change without notice.")
1527
class ExecuTorchModule:
1628
"""ExecuTorchModule is a Python wrapper around a C++ ExecuTorch program.
@@ -125,7 +137,10 @@ class MethodMeta:
125137

126138
@experimental("This API is experimental and subject to change without notice.")
127139
def _load_for_executorch(
128-
path: str, enable_etdump: bool = False, debug_buffer_size: int = 0
140+
path: str,
141+
enable_etdump: bool = False,
142+
debug_buffer_size: int = 0,
143+
program_verification: Verification = Verification.InternalConsistency,
129144
) -> ExecuTorchModule:
130145
"""Load an ExecuTorch Program from a file.
131146
@@ -148,7 +163,10 @@ def _load_for_executorch(
148163

149164
@experimental("This API is experimental and subject to change without notice.")
150165
def _load_for_executorch_from_buffer(
151-
buffer: bytes, enable_etdump: bool = False, debug_buffer_size: int = 0
166+
buffer: bytes,
167+
enable_etdump: bool = False,
168+
debug_buffer_size: int = 0,
169+
program_verification: Verification = Verification.InternalConsistency,
152170
) -> ExecuTorchModule:
153171
"""Same as _load_for_executorch, but takes a byte buffer instead of a file path.
154172

extension/pybindings/test/make_test.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-unsafe
88

99
import unittest
10+
from types import ModuleType
1011
from typing import Any, Callable, Optional, Tuple
1112

1213
import torch
@@ -17,7 +18,7 @@
1718

1819
def make_test( # noqa: C901
1920
tester: unittest.TestCase,
20-
load_fn: Callable,
21+
runtime: ModuleType,
2122
) -> Callable[[unittest.TestCase], None]:
2223
"""
2324
Returns a function that operates as a test case within a unittest.TestCase class.
@@ -26,6 +27,7 @@ def make_test( # noqa: C901
2627
which will all have different load functions. In this case each individual test case is a
2728
subfunction of wrapper.
2829
"""
30+
load_fn: Callable = runtime._load_for_executorch_from_buffer
2931

3032
def wrapper(tester: unittest.TestCase) -> None:
3133
class ModuleAdd(torch.nn.Module):
@@ -343,6 +345,40 @@ def test_method_meta(tester) -> None:
343345
tester.assertEqual(output_tensor.nbytes(), 16)
344346
tester.assertEqual(str(output_tensor), tensor_info)
345347

348+
def test_bad_name(tester) -> None:
349+
# Create an ExecuTorch program from ModuleAdd.
350+
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
351+
exported_program, inputs = create_program(ModuleAdd())
352+
353+
# Use pybindings to load and execute the program.
354+
executorch_module = load_fn(exported_program.buffer)
355+
# Invoke the callable on executorch_module instead of calling module.forward.
356+
with tester.assertRaises(RuntimeError):
357+
executorch_module.run_method("not_a_real_method", inputs)
358+
359+
def test_verification_config(tester) -> None:
360+
# Create an ExecuTorch program from ModuleAdd.
361+
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
362+
exported_program, inputs = create_program(ModuleAdd())
363+
Verification = runtime.Verification
364+
365+
# Use pybindings to load and execute the program.
366+
for config in [Verification.Minimal, Verification.InternalConsistency]:
367+
executorch_module = load_fn(
368+
exported_program.buffer,
369+
enable_etdump=False,
370+
debug_buffer_size=0,
371+
program_verification=config,
372+
)
373+
374+
executorch_output = executorch_module.forward(inputs)[0]
375+
376+
# The test module adds the two inputs, so its output should be the same
377+
# as adding them directly.
378+
expected = inputs[0] + inputs[1]
379+
380+
tester.assertEqual(str(expected), str(executorch_output))
381+
346382
######### RUN TEST CASES #########
347383
test_e2e(tester)
348384
test_multiple_entry(tester)
@@ -353,5 +389,7 @@ def test_method_meta(tester) -> None:
353389
test_quantized_ops(tester)
354390
test_constant_output_not_memory_planned(tester)
355391
test_method_meta(tester)
392+
test_bad_name(tester)
393+
test_verification_config(tester)
356394

357395
return wrapper

extension/pybindings/test/test_pybindings.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,19 @@
1010

1111
kernel_mode = None # either aten mode or portable mode
1212
try:
13-
from executorch.extension.pybindings.portable_lib import (
14-
_load_for_executorch_from_buffer,
15-
)
13+
from executorch.extension.pybindings import portable_lib as runtime
1614

1715
kernel_mode = "portable"
1816
except Exception:
1917
print("can't load portable lib")
2018

21-
try:
22-
from executorch.extension.pybindings.aten_lib import ( # noqa: F811
23-
_load_for_executorch_from_buffer,
24-
)
25-
26-
assert kernel_mode is None
19+
if kernel_mode is None:
20+
try:
21+
from executorch.extension.pybindings import aten_lib as runtime # noqa: F811
2722

28-
kernel_mode = "aten"
29-
except Exception:
30-
print("can't load aten lib")
23+
kernel_mode = "aten"
24+
except Exception:
25+
print("can't load aten lib")
3126

3227
assert kernel_mode is not None
3328

@@ -37,4 +32,4 @@
3732

3833
class PybindingsTest(unittest.TestCase):
3934
def test(self):
40-
make_test(self, _load_for_executorch_from_buffer)(self)
35+
make_test(self, runtime)(self)

0 commit comments

Comments
 (0)