Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extension/pybindings/portable_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
_dump_profile_results, # noqa: F401
_get_operator_names, # noqa: F401
_get_registered_backend_names, # noqa: F401
_is_available, # noqa: F401
_load_bundled_program_from_buffer, # noqa: F401
_load_for_executorch, # noqa: F401
_load_for_executorch_from_buffer, # noqa: F401
Expand Down
11 changes: 11 additions & 0 deletions extension/pybindings/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,12 @@ using ::executorch::extension::BufferDataLoader;
using ::executorch::extension::MallocMemoryAllocator;
using ::executorch::extension::MmapDataLoader;
using ::executorch::runtime::ArrayRef;
using ::executorch::runtime::BackendInterface;
using ::executorch::runtime::DataLoader;
using ::executorch::runtime::Error;
using ::executorch::runtime::EValue;
using ::executorch::runtime::EventTracerDebugLogLevel;
using ::executorch::runtime::get_backend_class;
using ::executorch::runtime::get_backend_name;
using ::executorch::runtime::get_num_registered_backends;
using ::executorch::runtime::get_registered_kernels;
Expand Down Expand Up @@ -990,6 +992,14 @@ py::list get_registered_backend_names() {
return res;
}

py::bool_ is_available(const std::string& backend_name) {
BackendInterface* backend = get_backend_class(backend_name.c_str());
if (backend == nullptr) {
return false;
}
return backend->is_available();
}

} // namespace

PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
Expand Down Expand Up @@ -1048,6 +1058,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
&get_registered_backend_names,
call_guard);
m.def("_get_operator_names", &get_operator_names);
m.def("_is_available", &is_available, py::arg("backend_name"), call_guard);
m.def("_create_profile_block", &create_profile_block, call_guard);
m.def(
"_reset_profile_results",
Expand Down
9 changes: 9 additions & 0 deletions extension/pybindings/pybindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,15 @@ def _load_bundled_program_from_buffer(
"""
...

@experimental("This API is experimental and subject to change without notice.")
def _is_available(backend_name: str) -> bool:
"""
.. warning::

This API is experimental and subject to change without notice.
"""
...

@experimental("This API is experimental and subject to change without notice.")
def _get_operator_names() -> List[str]:
"""
Expand Down
13 changes: 13 additions & 0 deletions extension/pybindings/test/test_backend_pybinding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,16 @@ def test_backend_name_list(
registered_backend_names = runtime.backend_registry.registered_backend_names
self.assertGreaterEqual(len(registered_backend_names), 1)
self.assertIn("XnnpackBackend", registered_backend_names)

def test_backend_is_available(
self,
) -> None:
# XnnpackBackend is available
runtime = Runtime.get()
self.assertTrue(
runtime.backend_registry.is_available(backend_name="XnnpackBackend")
)
# NonExistBackend doesn't exist and not available
self.assertFalse(
runtime.backend_registry.is_available(backend_name="NonExistBackend")
)
6 changes: 6 additions & 0 deletions runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ def registered_backend_names(self) -> List[str]:
"""
return self._legacy_module._get_registered_backend_names()

def is_available(self, backend_name: str) -> bool:
"""
Returns the names of all registered backends as a list of strings.
"""
return self._legacy_module._is_available(backend_name)


class OperatorRegistry:
"""The registry of operators that are available to the runtime."""
Expand Down
Loading