From 850d15ac2383d06e68e142f54360d48937aefbc2 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Tue, 11 Mar 2025 16:49:45 -0700
Subject: [PATCH] [ET-VK] Allow memory tagging pass to handle nodes with list
 of tensor args

## Context

Currently, the memory metadata tagging pass cannot insert memory type
transitions for operators that accept a list of tensors as an input
argument. This diff closes that gap by handling the case where a graph
node's argument is a list of nodes rather than a single node.

Differential Revision: [D71005186](https://our.internmc.facebook.com/intern/diff/D71005186/)

ghstack-source-id: 271146751
Pull Request resolved: https://github.com/pytorch/executorch/pull/9173
---
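Note (reviewer illustration, not part of the commit message): a typical
operator affected by this gap is concatenation, whose first argument is a
Python list of tensor nodes rather than a single node. The sketch below is
illustrative only; the model and variable names are hypothetical and are
not part of this diff. It shows the argument shape that should_annotate
and set_or_transition_arg must now recurse into:

    import torch
    from torch.export import export

    class CatModel(torch.nn.Module):
        def forward(self, x, y):
            # aten.cat receives [x, y] as a single list-typed argument
            return torch.cat([x, y], dim=1)

    ep = export(CatModel(), (torch.randn(1, 2), torch.randn(1, 2)))
    for node in ep.graph.nodes:
        if node.op == "call_function":
            # For the cat node, node.args[0] is a list of torch.fx.Node,
            # which the old pass skipped because it was not itself a Node.
            print(node.target, [type(a) for a in node.args])
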
 .../vulkan/_passes/tag_memory_meta_pass.py | 128 ++++++++++++------
 1 file changed, 88 insertions(+), 40 deletions(-)

diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index f2f54404ca8..03721066f1c 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -6,7 +6,7 @@
 
 import logging
 from copy import deepcopy
-from typing import Set
+from typing import Any, Set
 
 import executorch.backends.vulkan.utils as utils
 
@@ -190,20 +190,24 @@ def propose_node_layout(
         return next(iter(valid_layouts))
 
     def should_annotate(self, node) -> bool:
-        if not isinstance(node, torch.fx.Node):
-            return False
-
-        if not utils.is_tensor_node(node):
-            return False
-
-        # Storage type and memory layout for tensorref will be determined at runtime
-        # so there's no use in setting those attributes ahead of time.
-        if node.meta.get("vkdg_tensorref", False):
-            return False
-
-        # Skip annotating output node. The output tensors should be annotated by the
-        # time the output node is observed.
-        if node.op == "output":
+        if isinstance(node, torch.fx.Node):
+            if not utils.is_tensor_node(node):
+                return False
+
+            # Storage type and memory layout for tensorref will be determined at runtime
+            # so there's no use in setting those attributes ahead of time.
+            if node.meta.get("vkdg_tensorref", False):
+                return False
+
+            # Skip annotating output node. The output tensors should be annotated by the
+            # time the output node is observed.
+            if node.op == "output":
+                return False
+        elif isinstance(node, (list, tuple)):
+            return all(
+                isinstance(n, torch.fx.Node) and self.should_annotate(n) for n in node
+            )
+        else:
             return False
 
         return True
@@ -215,6 +219,70 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:
         # time the prepack node is observed.
         return node.target == exir_ops.edge.et_vk.prepack.default
 
+    def set_or_transition_arg_node(
+        self,
+        i: int,
+        arg: torch.fx.Node,
+        node: torch.fx.Node,
+        graph_module: torch.fx.GraphModule,
+        dirty: bool,
+    ) -> bool:
+        assert isinstance(arg, torch.fx.Node)
+
+        storage = utils.get_node_storage_type(node)
+        assert storage is not None
+        layout = utils.get_node_memory_layout(node)
+        assert layout is not None
+
+        arg_storage = utils.get_node_storage_type(arg)
+        arg_layout = utils.get_node_memory_layout(arg)
+
+        if arg_storage is None:
+            utils.set_node_spec_attr(arg, "vk_storage_type", storage)
+            arg_storage = storage
+        if arg_layout is None:
+            utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
+            arg_layout = layout
+
+        if arg_storage == storage and arg_layout == layout:
+            return False
+
+        if not dirty:
+            logger.info(
+                f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
+            )
+
+        insert_transition_node(graph_module, node, arg, storage, layout)
+
+        logger.info(
+            f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+        )
+
+        return True
+
+    def set_or_transition_arg(
+        self,
+        i: int,
+        arg: Any,
+        node: torch.fx.Node,
+        graph_module: torch.fx.GraphModule,
+        dirty: bool,
+    ) -> bool:
+        if isinstance(arg, torch.fx.Node):
+            return self.set_or_transition_arg_node(i, arg, node, graph_module, dirty)
+        elif isinstance(arg, (list, tuple)):
+            need_transition = False
+            for arg_node in arg:
+                need_transition = (
+                    self.set_or_transition_arg_node(
+                        i, arg_node, node, graph_module, need_transition
                    )
+                    or need_transition
+                )
+            return need_transition
+        else:
+            return False
+
     # noqa
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         for node in graph_module.graph.nodes:
@@ -226,36 +294,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
             set_memory_metadata(node, storage, layout)
 
-            inserting_transitions_for_node = False
+            need_transition = False
             for i, arg in enumerate(node.args):
                 if not self.should_annotate(arg):
                     continue
 
-                assert isinstance(arg, torch.fx.Node)
-
-                arg_storage = utils.get_node_storage_type(arg)
-                arg_layout = utils.get_node_memory_layout(arg)
-
-                if arg_storage is None:
-                    utils.set_node_spec_attr(arg, "vk_storage_type", storage)
-                    arg_storage = storage
-                if arg_layout is None:
-                    utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
-                    arg_layout = layout
-
-                if arg_storage == storage and arg_layout == layout:
-                    continue
-
-                if not inserting_transitions_for_node:
-                    inserting_transitions_for_node = True
-                    logger.info(
-                        f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
+                need_transition = (
+                    self.set_or_transition_arg(
+                        i, arg, node, graph_module, need_transition
                     )
-
-                insert_transition_node(graph_module, node, arg, storage, layout)
-
-                logger.info(
-                    f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+                    or need_transition
                 )
 
         return PassResult(graph_module, True)