1212import executorch .backends .vulkan .serialization .vulkan_graph_schema as vk_graph_schema
1313
1414import torch
15+ from executorch .backends .vulkan .utils import (
16+ is_constant ,
17+ is_get_attr_node ,
18+ is_param_node ,
19+ )
1520from executorch .exir .backend .utils import DelegateMappingBuilder
1621
1722from executorch .exir .tensor import TensorSpec
@@ -68,34 +73,12 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
6873 else :
6974 raise AssertionError (f"Invalid dtype for vulkan_preprocess ({ torch_dtype } )" )
7075
71- def is_constant (self , node : Node ):
72- return (
73- node .name in self .program .graph_signature .inputs_to_lifted_tensor_constants
74- )
75-
76- def is_get_attr_node (self , node : Node ) -> bool :
77- """
78- Returns true if the given node is a get attr node for a tensor of the model
79- """
80- return isinstance (node , Node ) and node .op == "get_attr"
81-
82- def is_param_node (self , node : Node ) -> bool :
83- """
84- Check if the given node is a parameter within the exported program
85- """
86- return (
87- self .is_get_attr_node (node )
88- or is_param (self .program , node )
89- or is_buffer (self .program , node )
90- or self .is_constant (node )
91- )
92-
9376 def get_constant (self , node : Node ) -> Optional [torch .Tensor ]:
9477 """
9578 Returns the constant associated with the given node in the exported program.
9679 Returns None if the node is not a constant within the exported program
9780 """
98- if self . is_constant (node ):
81+ if is_constant (self . program , node ):
9982 constant_name = (
10083 self .program .graph_signature .inputs_to_lifted_tensor_constants [
10184 node .name
@@ -116,9 +99,9 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
11699 tensor = get_param (self .program , node )
117100 elif is_buffer (self .program , node ):
118101 tensor = get_buffer (self .program , node )
119- elif self . is_constant (node ):
102+ elif is_constant (self . program , node ):
120103 tensor = self .get_constant (node )
121- elif self . is_get_attr_node (node ):
104+ elif is_get_attr_node (node ):
122105 # This is a hack to support both lifted and unlifted graph
123106 try :
124107 tensor = getattr (node .graph .owning_module , node .target )
@@ -132,7 +115,7 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
132115
133116 def maybe_add_constant_tensor (self , node : Node ) -> int :
134117 constant_id = - 1
135- if self . is_param_node (node ):
118+ if is_param_node (self . program , node ):
136119 constant_id = len (self .const_tensors )
137120 self .const_tensors .append (self .get_param_tensor (node ))
138121
@@ -280,7 +263,7 @@ def process_placeholder_node(self, node: Node) -> None:
280263 if len (node .users ) == 0 :
281264 return None
282265 ids = self .create_node_value (node )
283- if not self . is_param_node (node ):
266+ if not is_param_node (self . program , node ):
284267 if isinstance (ids , int ):
285268 self .input_ids .append (ids )
286269 else :
0 commit comments