From d51ccc13d9167bd5b0f0e6cbcebda50c17226021 Mon Sep 17 00:00:00 2001 From: Chris Thompson Date: Tue, 2 Sep 2025 12:47:26 -0700 Subject: [PATCH] Add pybindings for attribute tensors (#13579) Summary: Add pybindings for grabbing attribute tensor information from method meta Reviewed By: JacobSzwejbka Differential Revision: D80631040 --- extension/pybindings/pybindings.cpp | 21 +++++++++++++++ extension/pybindings/pybindings.pyi | 9 +++++++ extension/pybindings/test/test_pybindings.py | 27 ++++++++++++++++++-- 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index 7a9d8c1faf3..56f92356870 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -661,6 +661,21 @@ struct PyMethodMeta final { } } + size_t num_attributes() const { + return meta_.num_attributes(); + } + + std::unique_ptr attribute_tensor_meta(size_t index) const { + const auto result = meta_.attribute_tensor_meta(index); + THROW_INDEX_IF_ERROR( + result.error(), "Cannot get attribute tensor meta at %zu", index); + if (module_) { + return std::make_unique(module_, result.get()); + } else { + return std::make_unique(state_, result.get()); + } + } + py::str repr() const { py::list input_meta_strs; for (size_t i = 0; i < meta_.num_inputs(); ++i) { @@ -1641,6 +1656,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) { .def("name", &PyMethodMeta::name, call_guard) .def("num_inputs", &PyMethodMeta::num_inputs, call_guard) .def("num_outputs", &PyMethodMeta::num_outputs, call_guard) + .def("num_attributes", &PyMethodMeta::num_attributes, call_guard) .def( "input_tensor_meta", &PyMethodMeta::input_tensor_meta, @@ -1651,6 +1667,11 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) { &PyMethodMeta::output_tensor_meta, py::arg("index"), call_guard) + .def( + "attribute_tensor_meta", + &PyMethodMeta::attribute_tensor_meta, + py::arg("index"), + call_guard) .def("__repr__", &PyMethodMeta::repr, call_guard); m.def( diff --git a/extension/pybindings/pybindings.pyi b/extension/pybindings/pybindings.pyi index 7aede1c29a9..1978d22ea96 100644 --- a/extension/pybindings/pybindings.pyi +++ b/extension/pybindings/pybindings.pyi @@ -133,6 +133,10 @@ class MethodMeta: internal buffers""" ... + def num_attributes(self) -> int: + """The number of attribute tensors from the method""" + ... + def input_tensor_meta(self, index: int) -> TensorInfo: """The tensor info for the 'index'th input. Index must be in the interval [0, num_inputs()). Raises an IndexError if the index is out of bounds""" @@ -143,6 +147,11 @@ class MethodMeta: [0, num_outputs()). Raises an IndexError if the index is out of bounds""" ... + def attribute_tensor_meta(self, index: int) -> TensorInfo: + """The tensor info for the 'index'th attribute. Index must be in the interval + [0, num_attributes()). Raises an IndexError if the index is out of bounds""" + ... + def __repr__(self) -> str: ... @experimental("This API is experimental and subject to change without notice.") diff --git a/extension/pybindings/test/test_pybindings.py b/extension/pybindings/test/test_pybindings.py index 95f05bc98f6..8bbdb0d86d4 100644 --- a/extension/pybindings/test/test_pybindings.py +++ b/extension/pybindings/test/test_pybindings.py @@ -518,19 +518,32 @@ def test_method_attribute(self): ) def test_program_method_meta(self) -> None: - exported_program, inputs = create_program(ModuleAdd()) + eager_module = ModuleAddWithAttributes() + inputs = eager_module.get_inputs() + + exported_program = export(eager_module, inputs, strict=True) + exec_prog = to_edge(exported_program).to_executorch( + config=ExecutorchBackendConfig( + emit_mutable_buffer_names=True, + ) + ) + + exec_prog.dump_executorch_program(verbose=True) + + executorch_program = self.load_prog_fn(exec_prog.buffer) - executorch_program = self.load_prog_fn(exported_program.buffer) meta = executorch_program.method_meta("forward") del executorch_program self.assertEqual(meta.name(), "forward") self.assertEqual(meta.num_inputs(), 2) self.assertEqual(meta.num_outputs(), 1) + self.assertEqual(meta.num_attributes(), 1) tensor_info = ( "TensorInfo(sizes=[2, 2], dtype=Float, is_memory_planned=True, nbytes=16)" ) + float_dtype = 6 self.assertEqual( str(meta), @@ -541,10 +554,14 @@ def test_program_method_meta(self) -> None: input_tensors = [meta.input_tensor_meta(i) for i in range(2)] output_tensor = meta.output_tensor_meta(0) + attribute_tensor = meta.attribute_tensor_meta(0) with self.assertRaises(IndexError): meta.input_tensor_meta(2) + with self.assertRaises(IndexError): + meta.attribute_tensor_meta(1) + del meta self.assertEqual([t.sizes() for t in input_tensors], [(2, 2), (2, 2)]) self.assertEqual([t.dtype() for t in input_tensors], [float_dtype, float_dtype]) @@ -558,6 +575,12 @@ def test_program_method_meta(self) -> None: self.assertEqual(output_tensor.nbytes(), 16) self.assertEqual(str(output_tensor), tensor_info) + self.assertEqual(attribute_tensor.sizes(), (2, 2)) + self.assertEqual(attribute_tensor.dtype(), float_dtype) + self.assertEqual(attribute_tensor.is_memory_planned(), True) + self.assertEqual(attribute_tensor.nbytes(), 16) + self.assertEqual(str(attribute_tensor), tensor_info) + def test_method_method_meta(self) -> None: exported_program, inputs = create_program(ModuleAdd())