Skip to content

Commit 6a7b41c

Browse files
cccclaifacebook-github-bot
authored andcommitted
add backend is available
Summary: Add a pybind API to see if the backend is available Differential Revision: D69810445
1 parent 84273f4 commit 6a7b41c

File tree

5 files changed

+46
-1
lines changed

5 files changed

+46
-1
lines changed

extension/pybindings/portable_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
# related libs, ensuring that the pybindings lib can resolve those runtime
2828
# dependencies.
2929
import torch as _torch
30-
3130
# Let users import everything from the C++ _portable_lib extension as if this
3231
# python file defined them. Although we could import these dynamically, it
3332
# wouldn't preserve the static type annotations.
@@ -37,6 +36,7 @@
3736
# Disable "imported but unused" (F401) checks.
3837
_create_profile_block, # noqa: F401
3938
_dump_profile_results, # noqa: F401
39+
_is_available, # noqa: F401
4040
_get_operator_names, # noqa: F401
4141
_get_registered_backend_names, # noqa: F401
4242
_load_bundled_program_from_buffer, # noqa: F401

extension/pybindings/pybindings.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,12 @@ using ::executorch::extension::BufferDataLoader;
8888
using ::executorch::extension::MallocMemoryAllocator;
8989
using ::executorch::extension::MmapDataLoader;
9090
using ::executorch::runtime::ArrayRef;
91+
using ::executorch::runtime::BackendInterface;
9192
using ::executorch::runtime::DataLoader;
9293
using ::executorch::runtime::Error;
9394
using ::executorch::runtime::EValue;
9495
using ::executorch::runtime::EventTracerDebugLogLevel;
96+
using ::executorch::runtime::get_backend_class;
9597
using ::executorch::runtime::get_backend_name;
9698
using ::executorch::runtime::get_num_registered_backends;
9799
using ::executorch::runtime::get_registered_kernels;
@@ -990,6 +992,20 @@ py::list get_registered_backend_names() {
990992
return res;
991993
}
992994

995+
py::bool_ is_available(const std::string& target_backend_name) {
996+
size_t n_of_registered_backends = get_num_registered_backends();
997+
for (size_t i = 0; i < n_of_registered_backends; i++) {
998+
auto backend_name_res = get_backend_name(i);
999+
THROW_IF_ERROR(backend_name_res.error(), "Failed to get backend name");
1000+
auto backend_name = backend_name_res.get();
1001+
if(strcmp(backend_name, target_backend_name.c_str()) == 0) {
1002+
BackendInterface* backend = get_backend_class(backend_name);
1003+
return backend->is_available();
1004+
}
1005+
}
1006+
return false;
1007+
}
1008+
9931009
} // namespace
9941010

9951011
PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
@@ -1048,6 +1064,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10481064
&get_registered_backend_names,
10491065
call_guard);
10501066
m.def("_get_operator_names", &get_operator_names);
1067+
m.def("_is_available", &is_available, py::arg("backend_name"), call_guard);
10511068
m.def("_create_profile_block", &create_profile_block, call_guard);
10521069
m.def(
10531070
"_reset_profile_results",

extension/pybindings/pybindings.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,17 @@ def _load_bundled_program_from_buffer(
211211
"""
212212
...
213213

214+
215+
@experimental("This API is experimental and subject to change without notice.")
216+
def _is_available(backend_name: str) -> bool:
217+
"""
218+
.. warning::
219+
220+
This API is experimental and subject to change without notice.
221+
"""
222+
...
223+
224+
214225
@experimental("This API is experimental and subject to change without notice.")
215226
def _get_operator_names() -> List[str]:
216227
"""

extension/pybindings/test/test_backend_pybinding.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22

33
from executorch.runtime import Runtime
4+
from executorch.extension.pybindings.portable_lib import _is_available
45

56

67
class TestBackendsPybinding(unittest.TestCase):
@@ -12,3 +13,12 @@ def test_backend_name_list(
1213
registered_backend_names = runtime.backend_registry.registered_backend_names
1314
self.assertGreaterEqual(len(registered_backend_names), 1)
1415
self.assertIn("XnnpackBackend", registered_backend_names)
16+
17+
def test_backend_is_available(
18+
self,
19+
) -> None:
20+
# XnnpackBackend is available
21+
runtime = Runtime.get()
22+
self.assertTrue(runtime.backend_registry.is_available(backend_name="XnnpackBackend"))
23+
# NonExistBackend doesn't exist and not available
24+
self.assertFalse(runtime.backend_registry.is_available(backend_name="NonExistBackend"))

runtime/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,13 @@ def registered_backend_names(self) -> List[str]:
140140
return self._legacy_module._get_registered_backend_names()
141141

142142

143+
def is_available(self, backend_name: str) -> bool:
144+
"""
145+
Returns the names of all registered backends as a list of strings.
146+
"""
147+
return self._legacy_module._is_available(backend_name)
148+
149+
143150
class OperatorRegistry:
144151
"""The registry of operators that are available to the runtime."""
145152

0 commit comments

Comments
 (0)