Skip to content

Commit 5cc5421

Browse files
authored
Add Pybindings for Program.h/cpp
Differential Revision: D77388495 Pull Request resolved: #12016
1 parent 9d599c9 commit 5cc5421

File tree

3 files changed

+140
-0
lines changed

3 files changed

+140
-0
lines changed

extension/pybindings/portable_lib.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,13 @@
4444
_load_for_executorch, # noqa: F401
4545
_load_for_executorch_from_buffer, # noqa: F401
4646
_load_for_executorch_from_bundled_program, # noqa: F401
47+
_load_program, # noqa: F401
48+
_load_program_from_buffer, # noqa: F401
4749
_reset_profile_results, # noqa: F401
4850
_unsafe_reset_threadpool, # noqa: F401
4951
BundledModule, # noqa: F401
5052
ExecuTorchModule, # noqa: F401
53+
ExecuTorchProgram, # noqa: F401
5154
MethodMeta, # noqa: F401
5255
Verification, # noqa: F401
5356
)

extension/pybindings/pybindings.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,89 @@ struct PyModule final {
965965
}
966966
};
967967

968+
inline std::unique_ptr<DataLoader> loader_from_buffer(
969+
const void* ptr,
970+
size_t ptr_len) {
971+
return std::make_unique<BufferDataLoader>(ptr, ptr_len);
972+
}
973+
974+
inline std::unique_ptr<DataLoader> loader_from_file(const std::string& path) {
975+
Result<MmapDataLoader> res = MmapDataLoader::from(
976+
path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
977+
THROW_IF_ERROR(
978+
res.error(),
979+
"Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
980+
path.c_str(),
981+
static_cast<uint32_t>(res.error()));
982+
983+
return std::make_unique<MmapDataLoader>(std::move(res.get()));
984+
}
985+
986+
inline std::unique_ptr<Program> load_program(
987+
DataLoader* loader,
988+
Program::Verification program_verification) {
989+
Result<Program> res = Program::load(loader, program_verification);
990+
THROW_IF_ERROR(
991+
res.error(),
992+
"Failed to load program, error: 0x:%" PRIx32,
993+
static_cast<uint32_t>(res.error()));
994+
return std::make_unique<Program>(std::move(res.get()));
995+
}
996+
997+
struct PyProgram final {
998+
explicit PyProgram(
999+
const py::bytes& buffer,
1000+
Program::Verification program_verification =
1001+
Program::Verification::Minimal)
1002+
: loader_(loader_from_buffer(
1003+
buffer.cast<std::string_view>().data(),
1004+
py::len(buffer))),
1005+
program_(load_program(loader_.get(), program_verification)) {}
1006+
1007+
explicit PyProgram(
1008+
const std::string& path,
1009+
Program::Verification program_verification =
1010+
Program::Verification::Minimal)
1011+
: loader_(loader_from_file(path)),
1012+
program_(load_program(loader_.get(), program_verification)) {}
1013+
1014+
static std::unique_ptr<PyProgram> load_from_buffer(
1015+
const py::bytes& buffer,
1016+
Program::Verification program_verification =
1017+
Program::Verification::Minimal) {
1018+
return std::make_unique<PyProgram>(buffer, program_verification);
1019+
}
1020+
1021+
static std::unique_ptr<PyProgram> load_from_file(
1022+
const std::string& path,
1023+
Program::Verification program_verification =
1024+
Program::Verification::Minimal) {
1025+
return std::make_unique<PyProgram>(path, program_verification);
1026+
}
1027+
1028+
PyProgram(const PyProgram&) = delete;
1029+
PyProgram& operator=(const PyProgram&) = delete;
1030+
PyProgram(PyProgram&&) = default;
1031+
PyProgram& operator=(PyProgram&&) = default;
1032+
1033+
size_t num_methods() const {
1034+
return program_->num_methods();
1035+
}
1036+
1037+
std::string get_method_name(size_t method_index) const {
1038+
Result<const char*> res = program_->get_method_name(method_index);
1039+
THROW_IF_ERROR(
1040+
res.error(),
1041+
"Failed get method name, error: 0x:%" PRIx32,
1042+
static_cast<uint32_t>(res.error()));
1043+
return std::string(res.get());
1044+
}
1045+
1046+
private:
1047+
std::unique_ptr<DataLoader> loader_;
1048+
std::unique_ptr<Program> program_;
1049+
};
1050+
9681051
void create_profile_block(const std::string& name) {
9691052
EXECUTORCH_PROFILE_CREATE_BLOCK(name.c_str());
9701053
}
@@ -1151,6 +1234,26 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
11511234
py::arg("index"),
11521235
call_guard)
11531236
.def("__repr__", &PyMethodMeta::repr, call_guard);
1237+
1238+
m.def(
1239+
"_load_program",
1240+
&PyProgram::load_from_file,
1241+
py::arg("path"),
1242+
py::arg("program_verification") = Program::Verification::Minimal,
1243+
call_guard);
1244+
m.def(
1245+
"_load_program_from_buffer",
1246+
&PyProgram::load_from_buffer,
1247+
py::arg("buffer"),
1248+
py::arg("program_verification") = Program::Verification::Minimal,
1249+
call_guard);
1250+
py::class_<PyProgram>(m, "ExecuTorchProgram")
1251+
.def("num_methods", &PyProgram::num_methods, call_guard)
1252+
.def(
1253+
"get_method_name",
1254+
&PyProgram::get_method_name,
1255+
py::arg("method_index"),
1256+
call_guard);
11541257
}
11551258

11561259
namespace {

extension/pybindings/test/make_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def make_test( # noqa: C901
168168
subfunction of wrapper.
169169
"""
170170
load_fn: Callable = runtime._load_for_executorch_from_buffer
171+
load_prog_fn: Callable = runtime._load_program_from_buffer
171172

172173
def wrapper(tester: unittest.TestCase) -> None:
173174
######### TEST CASES #########
@@ -474,6 +475,36 @@ def test_unsupported_input_type(tester):
474475
# This should raise a Python error, not hit a fatal assert in the C++ code.
475476
tester.assertRaises(RuntimeError, executorch_module, inputs)
476477

478+
def test_program_methods_one(tester):
479+
# Create an ExecuTorch program from ModuleAdd.
480+
exported_program, _ = create_program(ModuleAdd())
481+
482+
# Use pybindings to load the program.
483+
executorch_program = load_prog_fn(exported_program.buffer)
484+
485+
tester.assertEqual(executorch_program.num_methods(), 1)
486+
tester.assertEqual(executorch_program.get_method_name(0), "forward")
487+
488+
def test_program_methods_multi(tester):
489+
# Create an ExecuTorch program from ModuleMulti.
490+
exported_program, _ = create_program(ModuleMulti())
491+
492+
# Use pybindings to load the program.
493+
executorch_program = load_prog_fn(exported_program.buffer)
494+
495+
tester.assertEqual(executorch_program.num_methods(), 2)
496+
tester.assertEqual(executorch_program.get_method_name(0), "forward")
497+
tester.assertEqual(executorch_program.get_method_name(1), "forward2")
498+
499+
def test_program_method_index_out_of_bounds(tester):
500+
# Create an ExecuTorch program from ModuleMulti.
501+
exported_program, _ = create_program(ModuleMulti())
502+
503+
# Use pybindings to load the program.
504+
executorch_program = load_prog_fn(exported_program.buffer)
505+
506+
tester.assertRaises(RuntimeError, executorch_program.get_method_name, 2)
507+
477508
######### RUN TEST CASES #########
478509
test_e2e(tester)
479510
test_multiple_entry(tester)
@@ -490,5 +521,8 @@ def test_unsupported_input_type(tester):
490521
test_bad_name(tester)
491522
test_verification_config(tester)
492523
test_unsupported_input_type(tester)
524+
test_program_methods_one(tester)
525+
test_program_methods_multi(tester)
526+
test_program_method_index_out_of_bounds(tester)
493527

494528
return wrapper

0 commit comments

Comments
 (0)