66
77# pyre-strict
88
9+ from copy import deepcopy
10+
911import executorch .backends .vulkan .custom_ops_lib # noqa
1012
1113import torch
1214
1315from executorch .backends .vulkan .op_registry import handles_own_prepacking
16+ from executorch .backends .vulkan .utils import is_param_node
1417
1518from executorch .exir .dialects ._ops import ops as exir_ops
1619
17- from torch ._export .utils import is_buffer , is_param
1820from torch .export import ExportedProgram
1921
2022
@@ -29,25 +31,8 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
2931 argument into the operator implementation.
3032 """
3133
32- def is_get_attr_node (node : torch .fx .Node ) -> bool :
33- return isinstance (node , torch .fx .Node ) and node .op == "get_attr"
34-
35- def is_constant (node : torch .fx .Node ) -> bool :
36- return node .name in program .graph_signature .inputs_to_lifted_tensor_constants
37-
38- def is_param_node (node : torch .fx .Node ) -> bool :
39- """
40- Check if the given node is a parameter within the exported program
41- """
42- return (
43- is_get_attr_node (node )
44- or is_param (program , node )
45- or is_buffer (program , node )
46- or is_constant (node )
47- )
48-
4934 def prepack_not_required (node : torch .fx .Node ) -> bool :
50- if not is_param_node (node ):
35+ if not is_param_node (program , node ):
5136 return True
5237
5338 for user in node .users :
@@ -69,9 +54,15 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
6954 exir_ops .edge .et_vk .prepack .default ,
7055 (node ,),
7156 )
72- prepack_node .meta ["spec" ] = node .meta ["spec" ]
57+ # This pass assumes that the SpecPropPass() has already been applied
58+ assert "spec" in node .meta
59+ # Validate that the original node is marked as a constant. Constant tensors
60+ # do not participate in memory planning.
61+ assert node .meta ["spec" ].const
62+ prepack_node .meta ["val" ] = node .meta ["val" ]
63+ prepack_node .meta ["spec" ] = deepcopy (node .meta ["spec" ])
7364 # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
74- # memory object. This pass must be executed AFTER the memory planning pass.
65+ # memory object.
7566 prepack_node .meta ["spec" ].mem_obj_id = - 1
7667 node .replace_all_uses_with (prepack_node , lambda x , y = prepack_node : x != y )
7768
0 commit comments