diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index bf1fc28ba56..ed736438cbb 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -8,10 +8,6 @@ from copy import deepcopy -import executorch.backends.vulkan.custom_ops_lib # noqa - -import torch - from executorch.backends.vulkan.op_registry import handles_own_prepacking from executorch.backends.vulkan.utils import is_param_node @@ -31,27 +27,27 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: argument into the operator implementation. """ - def prepack_not_required(node: torch.fx.Node) -> bool: + for node in program.graph_module.graph.nodes: + # Prepacking is only needed for constant tensors. Only nodes corresponding to + # constant tensors will proceed beyond this point. if not is_param_node(program, node): - return True + continue - # Annotate that this node is going to represented as a tensorref in the Vulkan - # compute graph. This will be useful for later graph passes. + # Mark that this node is going to be represented as a TensorRef type in the + # Vulkan compute graph. This annotation is used in later graph passes. node.meta["vkdg_tensorref"] = True + # Get the list of node users that do not handle their own prepacking + nodes_to_replace_input = [] for user in node.users: - if user.op == "call_function" and handles_own_prepacking( - # pyre-ignore - user.target - ): - return True + if user.op == "call_function" and not handles_own_prepacking(user.target): + nodes_to_replace_input.append(user) - return False - - for node in program.graph_module.graph.nodes: - if prepack_not_required(node): + if len(nodes_to_replace_input) == 0: continue + replace_all_uses = len(nodes_to_replace_input) == len(node.users) + with program.graph_module.graph.inserting_after(node): prepack_node = program.graph_module.graph.create_node( "call_function", @@ -74,9 +70,14 @@ def prepack_not_required(node: torch.fx.Node) -> bool: # Set the mem_obj_id to -1 to indicate that this node requires a dedicated # memory object. prepack_node.meta["spec"].mem_obj_id = -1 - node.replace_all_uses_with( - prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output") - ) + if replace_all_uses: + node.replace_all_uses_with( + prepack_node, + lambda x, y=prepack_node: (x != y and x.op != "output"), + ) + else: + for user_node in nodes_to_replace_input: + user_node.replace_input_with(node, prepack_node) program.graph.eliminate_dead_code() return program