Skip to content

Commit 2ece415

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Add _unsafe_reset_threadpool to pybindings
Summary: Expose thread count setting in pybindings. Currently, pybindings can be very slow on some server machines (observed 300x slower - went from 12s for 100 iterations to under 40ms for 100 iterations). As a stopgap measure, this diff exposes _unsafe_reset_threadpool from python, which can be used to reduce the thread count. This should be done prior to loading a model. Differential Revision: D71023514
1 parent e3c7954 commit 2ece415

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <executorch/extension/data_loader/buffer_data_loader.h>
2424
#include <executorch/extension/data_loader/mmap_data_loader.h>
2525
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
26+
#include <executorch/extension/threadpool/threadpool.h>
2627
#include <executorch/runtime/backend/interface.h>
2728
#include <executorch/runtime/core/data_loader.h>
2829
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -1064,6 +1065,12 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10641065
"_reset_profile_results",
10651066
[]() { EXECUTORCH_RESET_PROFILE_RESULTS(); },
10661067
call_guard);
1068+
m.def("_unsafe_reset_threadpool",
1069+
[](int num_threads) {
1070+
executorch::extension::threadpool::get_threadpool()->_unsafe_reset_threadpool(num_threads);
1071+
},
1072+
py::arg("num_threads")
1073+
);
10671074

10681075
py::class_<PyModule>(m, "ExecuTorchModule")
10691076
.def("load_bundled_input", &PyModule::load_bundled_input, call_guard)

shim_et/xplat/executorch/extension/pybindings/pybindings.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def executorch_pybindings(python_module_name, srcs = [], cppdeps = [], visibilit
5454
],
5555
deps = [
5656
"//executorch/runtime/core:core",
57+
"//executorch/extension/threadpool:threadpool",
5758
] + cppdeps,
5859
external_deps = [
5960
"pybind11",

0 commit comments

Comments
 (0)