Skip to content

Commit 8642496

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Handle unsupported pybind inputs non-fatally (#10670)
Summary: Currently, passing unsupported python types to pybind method execution, such as lists, dicts, or tuples, will crash the kernel due to hitting an assert. This PR updates the logic to raise an exception, which gets nicely bubbled up to the notebook. This gives the user a nicer error message and does not crash the bento/jupyter process. Differential Revision: D74118509
1 parent 94f7b10 commit 8642496

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,9 @@ struct PyModule final {
757757
} else if (py::isinstance<py::int_>(python_input)) {
758758
cpp_inputs.push_back(EValue(py::cast<int64_t>(python_input)));
759759
} else {
760-
ET_ASSERT_UNREACHABLE_MSG("Unsupported pytype: %s", type_str.c_str());
760+
throw std::runtime_error(
761+
"Unsupported python type " + type_str +
762+
". Ensure that inputs are passed as a flat list of tensors.");
761763
}
762764
}
763765

extension/pybindings/test/make_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,16 @@ def test_verification_config(tester) -> None:
464464

465465
tester.assertEqual(str(expected), str(executorch_output))
466466

467+
def test_unsupported_input_type(tester):
468+
exported_program, inputs = create_program(ModuleAdd())
469+
executorch_module = load_fn(exported_program.buffer)
470+
471+
# Pass an unsupported input type to the module.
472+
inputs = ([*inputs],)
473+
474+
# This should raise a Python error, not hit a fatal assert in the C++ code.
475+
tester.assertRaises(RuntimeError, executorch_module, inputs)
476+
467477
######### RUN TEST CASES #########
468478
test_e2e(tester)
469479
test_multiple_entry(tester)
@@ -479,5 +489,6 @@ def test_verification_config(tester) -> None:
479489
test_method_meta(tester)
480490
test_bad_name(tester)
481491
test_verification_config(tester)
492+
test_unsupported_input_type(tester)
482493

483494
return wrapper

0 commit comments

Comments
 (0)