Skip to content

Commit 9b95924

Browse files
Conarnarfacebook-github-bot
authored andcommitted
Add Pybindings for Program.h/cpp (pytorch#12016)
Summary: Today our python apis in executorch.runtime are implemented off of extension/pybindings which only offers a module api. We would like to migrate to having the lower level ET api exposed to python directly and then writing the module api in python. The first step to this is adding pybindings for Program. Bindings for the class Program and its methods num_methods and get_method_name were added. Test Plan: Tests were added to `extension/pybindings/test/make_test.py` 1. test_program_methods_one -- verifies num_methods and get_method_name works with one method 2. test_program_methods_multi -- verifies num_methods and get_method_name works with multiple methods 3. test_program_method_index_out_of_bounds -- verifies get_method_name raises a runtime error if index is out of bounds Rollback Plan: Reviewed By: JacobSzwejbka Differential Revision: D77388495 Pulled By: Conarnar
1 parent 3ba0466 commit 9b95924

File tree

3 files changed

+150
-0
lines changed

3 files changed

+150
-0
lines changed

extension/pybindings/portable_lib.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,13 @@
4444
_load_for_executorch, # noqa: F401
4545
_load_for_executorch_from_buffer, # noqa: F401
4646
_load_for_executorch_from_bundled_program, # noqa: F401
47+
_load_program, # noqa: F401
48+
_load_program_from_buffer, # noqa: F401
4749
_reset_profile_results, # noqa: F401
4850
_unsafe_reset_threadpool, # noqa: F401
4951
BundledModule, # noqa: F401
5052
ExecuTorchModule, # noqa: F401
53+
ExecuTorchProgram, # noqa: F401
5154
MethodMeta, # noqa: F401
5255
Verification, # noqa: F401
5356
)

extension/pybindings/pybindings.cpp

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,99 @@ struct PyModule final {
965965
}
966966
};
967967

968+
inline std::unique_ptr<DataLoader> loader_from_buffer(
969+
const void* ptr,
970+
size_t ptr_len) {
971+
return std::make_unique<BufferDataLoader>(ptr, ptr_len);
972+
}
973+
974+
inline std::unique_ptr<DataLoader> loader_from_file(const std::string& path) {
975+
Result<MmapDataLoader> res = MmapDataLoader::from(
976+
path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
977+
THROW_IF_ERROR(
978+
res.error(),
979+
"Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
980+
path.c_str(),
981+
static_cast<uint32_t>(res.error()));
982+
983+
return std::make_unique<MmapDataLoader>(std::move(res.get()));
984+
}
985+
986+
inline std::unique_ptr<Program> load_program(
987+
DataLoader* loader,
988+
Program::Verification program_verification) {
989+
Result<Program> res = Program::load(loader, program_verification);
990+
THROW_IF_ERROR(
991+
res.error(),
992+
"Failed to load program, error: 0x:%" PRIx32,
993+
static_cast<uint32_t>(res.error()));
994+
return std::make_unique<Program>(std::move(res.get()));
995+
}
996+
997+
struct PyProgram final {
998+
explicit PyProgram(
999+
const py::bytes& buffer,
1000+
Program::Verification program_verification =
1001+
Program::Verification::Minimal)
1002+
: loader_(loader_from_buffer(
1003+
buffer.cast<std::string_view>().data(),
1004+
py::len(buffer))),
1005+
program_(load_program(loader_.get(), program_verification)) {}
1006+
1007+
explicit PyProgram(
1008+
const std::string& path,
1009+
Program::Verification program_verification =
1010+
Program::Verification::Minimal)
1011+
: loader_(loader_from_file(path)),
1012+
program_(load_program(loader_.get(), program_verification)) {}
1013+
1014+
static std::unique_ptr<PyProgram> load_from_buffer(
1015+
const py::bytes& buffer,
1016+
Program::Verification program_verification =
1017+
Program::Verification::Minimal) {
1018+
return std::make_unique<PyProgram>(buffer, program_verification);
1019+
}
1020+
1021+
static std::unique_ptr<PyProgram> load_from_file(
1022+
const std::string& path,
1023+
Program::Verification program_verification =
1024+
Program::Verification::Minimal) {
1025+
return std::make_unique<PyProgram>(path, program_verification);
1026+
}
1027+
1028+
PyProgram(const PyProgram&) = delete;
1029+
PyProgram& operator=(const PyProgram&) = delete;
1030+
PyProgram(PyProgram&&) = default;
1031+
PyProgram& operator=(PyProgram&&) = default;
1032+
1033+
size_t num_methods() const {
1034+
return program_->num_methods();
1035+
}
1036+
1037+
std::string get_method_name(size_t method_index) const {
1038+
Result<const char*> res = program_->get_method_name(method_index);
1039+
THROW_IF_ERROR(
1040+
res.error(),
1041+
"Failed get method name, error: 0x:%" PRIx32,
1042+
static_cast<uint32_t>(res.error()));
1043+
return std::string(res.get());
1044+
}
1045+
1046+
std::string get_output_flattening_encoding(std::string method_name) const {
1047+
Result<const char*> res =
1048+
program_->get_output_flattening_encoding(method_name.c_str());
1049+
THROW_IF_ERROR(
1050+
res.error(),
1051+
"Failed get output flattening encoding, error: 0x:%" PRIx32,
1052+
static_cast<uint32_t>(res.error()));
1053+
return std::string(res.get());
1054+
}
1055+
1056+
private:
1057+
std::unique_ptr<DataLoader> loader_;
1058+
std::unique_ptr<Program> program_;
1059+
};
1060+
9681061
void create_profile_block(const std::string& name) {
9691062
EXECUTORCH_PROFILE_CREATE_BLOCK(name.c_str());
9701063
}
@@ -1151,6 +1244,26 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
11511244
py::arg("index"),
11521245
call_guard)
11531246
.def("__repr__", &PyMethodMeta::repr, call_guard);
1247+
1248+
m.def(
1249+
"_load_program",
1250+
&PyProgram::load_from_file,
1251+
py::arg("path"),
1252+
py::arg("program_verification") = Program::Verification::Minimal,
1253+
call_guard);
1254+
m.def(
1255+
"_load_program_from_buffer",
1256+
&PyProgram::load_from_buffer,
1257+
py::arg("buffer"),
1258+
py::arg("program_verification") = Program::Verification::Minimal,
1259+
call_guard);
1260+
py::class_<PyProgram>(m, "ExecuTorchProgram")
1261+
.def("num_methods", &PyProgram::num_methods, call_guard)
1262+
.def(
1263+
"get_method_name",
1264+
&PyProgram::get_method_name,
1265+
py::arg("method_index"),
1266+
call_guard);
11541267
}
11551268

11561269
namespace {

extension/pybindings/test/make_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def make_test( # noqa: C901
168168
subfunction of wrapper.
169169
"""
170170
load_fn: Callable = runtime._load_for_executorch_from_buffer
171+
load_prog_fn: Callable = runtime._load_program_from_buffer
171172

172173
def wrapper(tester: unittest.TestCase) -> None:
173174
######### TEST CASES #########
@@ -474,6 +475,36 @@ def test_unsupported_input_type(tester):
474475
# This should raise a Python error, not hit a fatal assert in the C++ code.
475476
tester.assertRaises(RuntimeError, executorch_module, inputs)
476477

478+
def test_program_methods_one(tester):
479+
# Create an ExecuTorch program from ModuleAdd.
480+
exported_program, _ = create_program(ModuleAdd())
481+
482+
# Use pybindings to load the program.
483+
executorch_program = load_prog_fn(exported_program.buffer)
484+
485+
tester.assertEqual(executorch_program.num_methods(), 1)
486+
tester.assertEqual(executorch_program.get_method_name(0), "forward")
487+
488+
def test_program_methods_multi(tester):
489+
# Create an ExecuTorch program from ModuleMulti.
490+
exported_program, _ = create_program(ModuleMulti())
491+
492+
# Use pybindings to load the program.
493+
executorch_program = load_prog_fn(exported_program.buffer)
494+
495+
tester.assertEqual(executorch_program.num_methods(), 2)
496+
tester.assertEqual(executorch_program.get_method_name(0), "forward")
497+
tester.assertEqual(executorch_program.get_method_name(1), "forward2")
498+
499+
def test_program_method_index_out_of_bounds(tester):
500+
# Create an ExecuTorch program from ModuleMulti.
501+
exported_program, _ = create_program(ModuleMulti())
502+
503+
# Use pybindings to load the program.
504+
executorch_program = load_prog_fn(exported_program.buffer)
505+
506+
tester.assertRaises(RuntimeError, executorch_program.get_method_name, 2)
507+
477508
######### RUN TEST CASES #########
478509
test_e2e(tester)
479510
test_multiple_entry(tester)
@@ -490,5 +521,8 @@ def test_unsupported_input_type(tester):
490521
test_bad_name(tester)
491522
test_verification_config(tester)
492523
test_unsupported_input_type(tester)
524+
test_program_methods_one(tester)
525+
test_program_methods_multi(tester)
526+
test_program_method_index_out_of_bounds(tester)
493527

494528
return wrapper

0 commit comments

Comments
 (0)