
Commit 939aabd

Update

[ghstack-poisoned]

2 parents: f6c43ff + 824aebf

File tree

8 files changed: +112 -43 lines


.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion

@@ -234,7 +234,7 @@ jobs:
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
     with:
       runner: macos-m1-stable
-      python-version: '3.11'
+      python-version: "3.11"
       submodules: 'true'
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       script: |

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 88 additions & 40 deletions

@@ -6,7 +6,7 @@

 import logging
 from copy import deepcopy
-from typing import Set
+from typing import Any, Set

 import executorch.backends.vulkan.utils as utils

@@ -190,20 +190,24 @@ def propose_node_layout(
         return next(iter(valid_layouts))

     def should_annotate(self, node) -> bool:
-        if not isinstance(node, torch.fx.Node):
-            return False
-
-        if not utils.is_tensor_node(node):
-            return False
-
-        # Storage type and memory layout for tensorref will be determined at runtime
-        # so there's no use in setting those attributes ahead of time.
-        if node.meta.get("vkdg_tensorref", False):
-            return False
-
-        # Skip annotating output node. The output tensors should be annotated by the
-        # time the output node is observed.
-        if node.op == "output":
+        if isinstance(node, torch.fx.Node):
+            if not utils.is_tensor_node(node):
+                return False
+
+            # Storage type and memory layout for tensorref will be determined at runtime
+            # so there's no use in setting those attributes ahead of time.
+            if node.meta.get("vkdg_tensorref", False):
+                return False
+
+            # Skip annotating output node. The output tensors should be annotated by the
+            # time the output node is observed.
+            if node.op == "output":
+                return False
+        elif isinstance(node, (list, tuple)):
+            return all(
+                isinstance(n, torch.fx.Node) and self.should_annotate(n) for n in node
+            )
+        else:
             return False

         return True
@@ -215,6 +219,70 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:
         # time the prepack node is observed.
         return node.target == exir_ops.edge.et_vk.prepack.default

+    def set_or_transition_arg_node(
+        self,
+        i: int,
+        arg: torch.fx.Node,
+        node: torch.fx.Node,
+        graph_module: torch.fx.GraphModule,
+        dirty: bool,
+    ) -> bool:
+        assert isinstance(arg, torch.fx.Node)
+
+        storage = utils.get_node_storage_type(node)
+        assert storage is not None
+        layout = utils.get_node_memory_layout(node)
+        assert layout is not None
+
+        arg_storage = utils.get_node_storage_type(arg)
+        arg_layout = utils.get_node_memory_layout(arg)
+
+        if arg_storage is None:
+            utils.set_node_spec_attr(arg, "vk_storage_type", storage)
+            arg_storage = storage
+        if arg_layout is None:
+            utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
+            arg_layout = layout
+
+        if arg_storage == storage and arg_layout == layout:
+            return False
+
+        if not dirty:
+            logger.info(
+                f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
+            )
+
+        insert_transition_node(graph_module, node, arg, storage, layout)
+
+        logger.info(
+            f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+        )
+
+        return True
+
+    def set_or_transition_arg(
+        self,
+        i: int,
+        arg: Any,
+        node: torch.fx.Node,
+        graph_module: torch.fx.GraphModule,
+        dirty: bool,
+    ) -> bool:
+        if isinstance(arg, torch.fx.Node):
+            return self.set_or_transition_arg_node(i, arg, node, graph_module, dirty)
+        elif isinstance(arg, (list, tuple)):
+            need_transition = False
+            for arg_node in arg:
+                need_transition = (
+                    self.set_or_transition_arg_node(
+                        i, arg_node, node, graph_module, need_transition
+                    )
+                    or need_transition
+                )
+            return need_transition
+        else:
+            return False
+
     # noqa
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         for node in graph_module.graph.nodes:
@@ -226,36 +294,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

             set_memory_metadata(node, storage, layout)

-            inserting_transitions_for_node = False
+            need_transition = False
             for i, arg in enumerate(node.args):
                 if not self.should_annotate(arg):
                     continue

-                assert isinstance(arg, torch.fx.Node)
-
-                arg_storage = utils.get_node_storage_type(arg)
-                arg_layout = utils.get_node_memory_layout(arg)
-
-                if arg_storage is None:
-                    utils.set_node_spec_attr(arg, "vk_storage_type", storage)
-                    arg_storage = storage
-                if arg_layout is None:
-                    utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
-                    arg_layout = layout
-
-                if arg_storage == storage and arg_layout == layout:
-                    continue
-
-                if not inserting_transitions_for_node:
-                    inserting_transitions_for_node = True
-                    logger.info(
-                        f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
+                need_transition = (
+                    self.set_or_transition_arg(
+                        i, arg, node, graph_module, need_transition
                     )
-
-                insert_transition_node(graph_module, node, arg, storage, layout)
-
-                logger.info(
-                    f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+                    or need_transition
                 )

         return PassResult(graph_module, True)
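
A side note on the refactor above: the accumulation is written as need_transition = self.set_or_transition_arg(...) or need_transition, with the call on the left of the or. Since or short-circuits, the reversed order (need_transition or self.set_or_transition_arg(...)) would stop invoking the helper once any earlier argument required a transition, silently skipping transitions for the remaining args. A minimal, self-contained sketch of the pattern (illustration only, not ExecuTorch code; the transition helper and its even-index rule are invented):

def transition(i: int, dirty: bool) -> bool:
    # Pretend even-indexed args need a transition; emit the header only
    # the first time, mirroring the `if not dirty:` check in
    # set_or_transition_arg_node above.
    needs = i % 2 == 0
    if needs:
        if not dirty:
            print("[Vulkan Delegate] Inserting transition(s) for <node>:")
        print(f"  args {i}: transition inserted")
    return needs

need_transition = False
for i in range(4):
    # Call first, then fold with `or`, so transition() runs for every i.
    need_transition = transition(i, need_transition) or need_transition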

build/build_apple_frameworks.sh

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@ libmicrokernels-prod.a,\

 FRAMEWORK_KERNELS_CUSTOM="kernels_custom:\
 libcustom_ops.a,\
+libextension_threadpool.a,\
 :"

 FRAMEWORK_KERNELS_OPTIMIZED="kernels_optimized:\

extension/benchmark/android/benchmark/app/build.gradle.kts

Lines changed: 2 additions & 2 deletions

@@ -14,8 +14,8 @@ android {

     defaultConfig {
         applicationId = "org.pytorch.minibench"
-        minSdk = 34
-        targetSdk = 34
+        minSdk = 28
+        targetSdk = 33
         versionCode = 1
         versionName = "1.0"
extension/pybindings/portable_lib.py

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@
     _load_for_executorch_from_buffer,  # noqa: F401
     _load_for_executorch_from_bundled_program,  # noqa: F401
     _reset_profile_results,  # noqa: F401
+    _unsafe_reset_threadpool,  # noqa: F401
     BundledModule,  # noqa: F401
     ExecuTorchModule,  # noqa: F401
     MethodMeta,  # noqa: F401

extension/pybindings/pybindings.cpp

Lines changed: 9 additions & 0 deletions

@@ -23,6 +23,7 @@
 #include <executorch/extension/data_loader/buffer_data_loader.h>
 #include <executorch/extension/data_loader/mmap_data_loader.h>
 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
+#include <executorch/extension/threadpool/threadpool.h>
 #include <executorch/runtime/backend/interface.h>
 #include <executorch/runtime/core/data_loader.h>
 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -1064,6 +1065,14 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
       "_reset_profile_results",
       []() { EXECUTORCH_RESET_PROFILE_RESULTS(); },
       call_guard);
+  m.def(
+      "_unsafe_reset_threadpool",
+      [](int num_threads) {
+        executorch::extension::threadpool::get_threadpool()
+            ->_unsafe_reset_threadpool(num_threads);
+      },
+      py::arg("num_threads"),
+      call_guard);

   py::class_<PyModule>(m, "ExecuTorchModule")
       .def("load_bundled_input", &PyModule::load_bundled_input, call_guard)

extension/pybindings/pybindings.pyi

Lines changed: 9 additions & 0 deletions

@@ -264,3 +264,12 @@ def _reset_profile_results() -> None:
     This API is experimental and subject to change without notice.
     """
     ...
+
+@experimental("This API is experimental and subject to change without notice.")
+def _unsafe_reset_threadpool(num_threads: int) -> None:
+    """
+    .. warning::
+
+        This API is experimental and subject to change without notice.
+    """
+    ...
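
Taken together, the pybindings changes expose a way to resize ExecuTorch's global threadpool from Python before running a model. A hedged usage sketch (the import path and loader come from portable_lib.py above; "model.pte" is a placeholder path, and the underscore-prefixed name plus the .pyi warning mark the API as experimental):

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
    _unsafe_reset_threadpool,
)

# Resize the shared threadpool before loading/running the module;
# presumably "unsafe" because the reset is not synchronized against
# concurrent users of the pool.
_unsafe_reset_threadpool(4)

with open("model.pte", "rb") as f:  # placeholder model file
    module = _load_for_executorch_from_buffer(f.read())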

shim_et/xplat/executorch/extension/pybindings/pybindings.bzl

Lines changed: 1 addition & 0 deletions

@@ -54,6 +54,7 @@ def executorch_pybindings(python_module_name, srcs = [], cppdeps = [], visibilit
         ],
         deps = [
             "//executorch/runtime/core:core",
+            "//executorch/extension/threadpool:threadpool",
         ] + cppdeps,
         external_deps = [
             "pybind11",
