1515from executorch .backends .arm .tosa_mapping import TosaArg
1616from executorch .backends .arm .tosa_specification import TosaSpecification
1717from executorch .backends .arm .tosa_utils import getNodeArgs , tosa_shape
18+ from torch ._export .utils import (
19+ get_buffer ,
20+ get_lifted_tensor_constant ,
21+ get_param ,
22+ is_buffer ,
23+ is_lifted_tensor_constant ,
24+ is_param ,
25+ )
1826from torch .export .exported_program import ExportedProgram
1927
2028
@@ -99,8 +107,7 @@ def process_inputs_to_parameters(
99107 f"Failed processing parameter placeholder: { node .name } . "
100108 "Is the original torch function supported?"
101109 ) from e
102- parameter_name = edge_program .graph_signature .inputs_to_parameters [tosa_arg .name ]
103- parameter_data = edge_program .state_dict [parameter_name ]
110+ parameter_data = get_param (edge_program , node )
104111
105112 assert isinstance (parameter_data , torch .Tensor ), "Expect Attr to be tensor"
106113 parameter_values = parameter_data .detach ().numpy ()
@@ -128,8 +135,7 @@ def process_inputs_to_buffers(
128135 f"Failed processing buffer placeholder: { node .name } . "
129136 "Is the original torch function supported?"
130137 ) from e
131- buffer_name = edge_program .graph_signature .inputs_to_buffers [node .name ]
132- buffer_data = edge_program .state_dict [buffer_name ]
138+ buffer_data = get_buffer (edge_program , node )
133139
134140 assert isinstance (buffer_data , torch .Tensor ), "Expect Attr to be tensor"
135141 buffer_values = buffer_data .detach ().numpy ()
@@ -156,11 +162,8 @@ def process_inputs_to_lifted_tensor_constants(
156162 f"Failed processing lifted tensor constant placeholder: { node .name } . "
157163 "Is the original torch function supported?"
158164 ) from e
159- tensor_name = edge_program .graph_signature .inputs_to_lifted_tensor_constants [
160- tosa_arg .name
161- ]
162- tensor = edge_program .tensor_constants [tensor_name ]
163- tensor_data = tensor .detach ().numpy ()
165+ tensor = get_lifted_tensor_constant (edge_program , node )
166+ tensor_data = tensor .detach ().numpy () # type: ignore[union-attr]
164167
165168 tosa_graph .addConst (
166169 tensor_data .shape , tosa_arg .dtype , tensor_data , name = tosa_arg .name
@@ -179,11 +182,11 @@ def process_placeholder(
179182
180183 if node .name in edge_program .graph_signature .user_inputs :
181184 process_inputs (node , tosa_graph , tosa_spec )
182- elif node . name in edge_program . graph_signature . inputs_to_parameters :
185+ elif is_param ( edge_program , node ) :
183186 process_inputs_to_parameters (node , tosa_graph , edge_program , tosa_spec )
184- elif node . name in edge_program . graph_signature . inputs_to_buffers :
187+ elif is_buffer ( edge_program , node ) :
185188 process_inputs_to_buffers (node , tosa_graph , edge_program )
186- elif node . name in edge_program . graph_signature . inputs_to_lifted_tensor_constants :
189+ elif is_lifted_tensor_constant ( edge_program , node ) :
187190 process_inputs_to_lifted_tensor_constants (node , tosa_graph , edge_program )
188191 elif node .name in edge_program .graph_signature .inputs_to_lifted_custom_objs :
189192 raise NotImplementedError (
0 commit comments