Skip to content

Commit 33b0638

Browse files
Conarnarfacebook-github-bot
authored andcommitted
Added PyProgram methods and tests
Differential Revision: D77388495
1 parent 9f2c0f5 commit 33b0638

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1012,18 +1012,47 @@ struct PyProgram final {
10121012
program_(load_program(
10131013
loader_.get(),
10141014
program_verification)) {}
1015+
10151016
static std::unique_ptr<PyProgram> load_from_buffer(
10161017
const py::bytes& buffer,
10171018
Program::Verification program_verification =
10181019
Program::Verification::Minimal) {
10191020
return std::make_unique<PyProgram>(buffer, program_verification);
10201021
}
1022+
10211023
static std::unique_ptr<PyProgram> load_from_file(
10221024
const std::string& path,
10231025
Program::Verification program_verification =
10241026
Program::Verification::Minimal) {
10251027
return std::make_unique<PyProgram>(path, program_verification);
10261028
}
1029+
1030+
PyProgram(const PyProgram&) = delete;
1031+
PyProgram& operator=(const PyProgram&) = delete;
1032+
PyProgram(PyProgram&&) = default;
1033+
PyProgram& operator=(PyProgram&&) = default;
1034+
1035+
size_t num_methods() const {
1036+
return program_->num_methods();
1037+
}
1038+
1039+
std::string get_method_name(size_t method_index) const {
1040+
Result<const char*> res = program_->get_method_name(method_index);
1041+
THROW_IF_ERROR(
1042+
res.error(),
1043+
"Failed get method name, error: 0x:%" PRIx32,
1044+
static_cast<uint32_t>(res.error()));
1045+
return std::string(res.get());
1046+
}
1047+
1048+
std::string get_output_flattening_encoding(std::string method_name) const {
1049+
Result<const char*> res = program_->get_output_flattening_encoding(method_name.c_str());
1050+
THROW_IF_ERROR(
1051+
res.error(),
1052+
"Failed get output flattening encoding, error: 0x:%" PRIx32,
1053+
static_cast<uint32_t>(res.error()));
1054+
return std::string(res.get());
1055+
}
10271056
private:
10281057
std::unique_ptr<DataLoader> loader_;
10291058
std::unique_ptr<Program> program_;
@@ -1231,7 +1260,13 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
12311260
py::arg("program_verification") =
12321261
Program::Verification::Minimal,
12331262
call_guard);
1234-
py::class_<PyProgram>(m, "ExecuTorchProgram");
1263+
py::class_<PyProgram>(m, "ExecuTorchProgram")
1264+
.def("num_methods", &PyProgram::num_methods, call_guard)
1265+
.def(
1266+
"get_method_name",
1267+
&PyProgram::get_method_name,
1268+
py::arg("method_index"),
1269+
call_guard);
12351270
}
12361271

12371272
namespace {

extension/pybindings/test/make_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def make_test( # noqa: C901
168168
subfunction of wrapper.
169169
"""
170170
load_fn: Callable = runtime._load_for_executorch_from_buffer
171+
load_prog_fn: Callable = runtime._load_program_from_buffer
171172

172173
def wrapper(tester: unittest.TestCase) -> None:
173174
######### TEST CASES #########
@@ -474,6 +475,24 @@ def test_unsupported_input_type(tester):
474475
# This should raise a Python error, not hit a fatal assert in the C++ code.
475476
tester.assertRaises(RuntimeError, executorch_module, inputs)
476477

478+
def test_program_methods_one(tester):
479+
exported_program, _ = create_program(ModuleAdd())
480+
executorch_program = load_prog_fn(exported_program.buffer)
481+
tester.assertEqual(executorch_program.num_methods(), 1)
482+
tester.assertEqual(executorch_program.get_method_name(0), "forward")
483+
484+
def test_program_methods_multi(tester):
485+
exported_program, _ = create_program(ModuleMulti())
486+
executorch_program = load_prog_fn(exported_program.buffer)
487+
tester.assertEqual(executorch_program.num_methods(), 2)
488+
tester.assertEqual(executorch_program.get_method_name(0), "forward")
489+
tester.assertEqual(executorch_program.get_method_name(1), "forward2")
490+
491+
def test_program_method_index_out_of_bounds(tester):
492+
exported_program, _ = create_program(ModuleMulti())
493+
executorch_program = load_prog_fn(exported_program.buffer)
494+
tester.assertRaises(RuntimeError, executorch_program.get_method_name, 2)
495+
477496
######### RUN TEST CASES #########
478497
test_e2e(tester)
479498
test_multiple_entry(tester)
@@ -490,5 +509,8 @@ def test_unsupported_input_type(tester):
490509
test_bad_name(tester)
491510
test_verification_config(tester)
492511
test_unsupported_input_type(tester)
512+
test_program_methods_one(tester)
513+
test_program_methods_multi(tester)
514+
test_program_method_index_out_of_bounds(tester)
493515

494516
return wrapper

0 commit comments

Comments
 (0)