Skip to content

Commit ddb1ef0

Browse files
cccclaifacebook-github-bot
authored andcommitted
add backend is available (pytorch#8738)
Summary: Add a pybind API to see if the backend is available Differential Revision: D69810445
1 parent 88a06ef commit ddb1ef0

File tree

5 files changed

+40
-1
lines changed

5 files changed

+40
-1
lines changed

extension/pybindings/portable_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
# related libs, ensuring that the pybindings lib can resolve those runtime
2828
# dependencies.
2929
import torch as _torch
30-
3130
# Let users import everything from the C++ _portable_lib extension as if this
3231
# python file defined them. Although we could import these dynamically, it
3332
# wouldn't preserve the static type annotations.
@@ -37,6 +36,7 @@
3736
# Disable "imported but unused" (F401) checks.
3837
_create_profile_block, # noqa: F401
3938
_dump_profile_results, # noqa: F401
39+
_is_available, # noqa: F401
4040
_get_operator_names, # noqa: F401
4141
_get_registered_backend_names, # noqa: F401
4242
_load_bundled_program_from_buffer, # noqa: F401

extension/pybindings/pybindings.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,12 @@ using ::executorch::extension::BufferDataLoader;
8888
using ::executorch::extension::MallocMemoryAllocator;
8989
using ::executorch::extension::MmapDataLoader;
9090
using ::executorch::runtime::ArrayRef;
91+
using ::executorch::runtime::BackendInterface;
9192
using ::executorch::runtime::DataLoader;
9293
using ::executorch::runtime::Error;
9394
using ::executorch::runtime::EValue;
9495
using ::executorch::runtime::EventTracerDebugLogLevel;
96+
using ::executorch::runtime::get_backend_class;
9597
using ::executorch::runtime::get_backend_name;
9698
using ::executorch::runtime::get_num_registered_backends;
9799
using ::executorch::runtime::get_registered_kernels;
@@ -990,6 +992,14 @@ py::list get_registered_backend_names() {
990992
return res;
991993
}
992994

995+
py::bool_ is_available(const std::string& target_backend_name) {
996+
BackendInterface* backend = get_backend_class(target_backend_name.c_str());
997+
if (backend == nullptr) {
998+
return false;
999+
}
1000+
return backend->is_available();
1001+
}
1002+
9931003
} // namespace
9941004

9951005
PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
@@ -1048,6 +1058,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10481058
&get_registered_backend_names,
10491059
call_guard);
10501060
m.def("_get_operator_names", &get_operator_names);
1061+
m.def("_is_available", &is_available, py::arg("backend_name"), call_guard);
10511062
m.def("_create_profile_block", &create_profile_block, call_guard);
10521063
m.def(
10531064
"_reset_profile_results",

extension/pybindings/pybindings.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,17 @@ def _load_bundled_program_from_buffer(
211211
"""
212212
...
213213

214+
215+
@experimental("This API is experimental and subject to change without notice.")
216+
def _is_available(backend_name: str) -> bool:
217+
"""
218+
.. warning::
219+
220+
This API is experimental and subject to change without notice.
221+
"""
222+
...
223+
224+
214225
@experimental("This API is experimental and subject to change without notice.")
215226
def _get_operator_names() -> List[str]:
216227
"""

extension/pybindings/test/test_backend_pybinding.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22

33
from executorch.runtime import Runtime
4+
from executorch.extension.pybindings.portable_lib import _is_available
45

56

67
class TestBackendsPybinding(unittest.TestCase):
@@ -12,3 +13,12 @@ def test_backend_name_list(
1213
registered_backend_names = runtime.backend_registry.registered_backend_names
1314
self.assertGreaterEqual(len(registered_backend_names), 1)
1415
self.assertIn("XnnpackBackend", registered_backend_names)
16+
17+
def test_backend_is_available(
18+
self,
19+
) -> None:
20+
# XnnpackBackend is available
21+
runtime = Runtime.get()
22+
self.assertTrue(runtime.backend_registry.is_available(backend_name="XnnpackBackend"))
23+
# NonExistBackend doesn't exist and not available
24+
self.assertFalse(runtime.backend_registry.is_available(backend_name="NonExistBackend"))

runtime/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,13 @@ def registered_backend_names(self) -> List[str]:
140140
return self._legacy_module._get_registered_backend_names()
141141

142142

143+
def is_available(self, backend_name: str) -> bool:
144+
"""
145+
Returns the names of all registered backends as a list of strings.
146+
"""
147+
return self._legacy_module._is_available(backend_name)
148+
149+
143150
class OperatorRegistry:
144151
"""The registry of operators that are available to the runtime."""
145152

0 commit comments

Comments
 (0)