88
99from copy import deepcopy
1010
11- import executorch .backends .vulkan .custom_ops_lib # noqa
12-
13- import torch
14-
1511from executorch .backends .vulkan .op_registry import handles_own_prepacking
1612from executorch .backends .vulkan .utils import is_param_node
1713
@@ -31,27 +27,27 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
3127 argument into the operator implementation.
3228 """
3329
34- def prepack_not_required (node : torch .fx .Node ) -> bool :
30+ for node in program .graph_module .graph .nodes :
31+ # Prepacking is only needed for constant tensors. Only nodes corresponding to
32+ # constant tensors will proceed beyond this point.
3533 if not is_param_node (program , node ):
36- return True
34+ continue
3735
38- # Annotate that this node is going to represented as a tensorref in the Vulkan
39- # compute graph. This will be useful for later graph passes.
36+ # Mark that this node is going to be represented as a TensorRef type in the
37+ # Vulkan compute graph. This annotation is used in later graph passes.
4038 node .meta ["vkdg_tensorref" ] = True
4139
40+ # Get the list of node users that do not handle their own prepacking
41+ nodes_to_replace_input = []
4242 for user in node .users :
43- if user .op == "call_function" and handles_own_prepacking (
44- # pyre-ignore
45- user .target
46- ):
47- return True
43+ if user .op == "call_function" and not handles_own_prepacking (user .target ):
44+ nodes_to_replace_input .append (user )
4845
49- return False
50-
51- for node in program .graph_module .graph .nodes :
52- if prepack_not_required (node ):
46+ if len (nodes_to_replace_input ) == 0 :
5347 continue
5448
49+ replace_all_uses = len (nodes_to_replace_input ) == len (node .users )
50+
5551 with program .graph_module .graph .inserting_after (node ):
5652 prepack_node = program .graph_module .graph .create_node (
5753 "call_function" ,
@@ -74,9 +70,14 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
7470 # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
7571 # memory object.
7672 prepack_node .meta ["spec" ].mem_obj_id = - 1
77- node .replace_all_uses_with (
78- prepack_node , lambda x , y = prepack_node : (x != y and x .op != "output" )
79- )
73+ if replace_all_uses :
74+ node .replace_all_uses_with (
75+ prepack_node ,
76+ lambda x , y = prepack_node : (x != y and x .op != "output" ),
77+ )
78+ else :
79+ for user_node in nodes_to_replace_input :
80+ user_node .replace_input_with (node , prepack_node )
8081
8182 program .graph .eliminate_dead_code ()
8283 return program
0 commit comments