1414from executorch .backends .arm .operators .node_visitor import NodeVisitor
1515from executorch .backends .arm .tosa_mapping import TosaArg
1616from executorch .backends .arm .tosa_specification import TosaSpecification
17- from executorch .backends .arm .tosa_utils import getNodeArgs , tosa_shape
17+ from executorch .backends .arm .tosa_utils import (
18+ get_node_debug_info ,
19+ getNodeArgs ,
20+ tosa_shape ,
21+ )
1822from torch .export .exported_program import ExportedProgram
1923
2024
@@ -28,8 +32,13 @@ def process_call_function(
2832 inputs = getNodeArgs (node )
2933
3034 # Convert output (this node itself)
31- output = TosaArg (node )
32-
35+ try :
36+ output = TosaArg (node )
37+ except ValueError as e :
38+ raise ValueError (
39+ f"Failed processing call_function:\n { get_node_debug_info (node )} "
40+ "Is the original torch function supported?"
41+ ) from e
3342 tosa_graph .currRegion .currBasicBlock .addTensor (
3443 output .name , tosa_shape (output .shape , output .dim_order ), output .dtype
3544 )
@@ -61,15 +70,21 @@ def process_inputs(
6170 f"Arm backend only supports contiguous memory format for inputs. "
6271 f"Expected dim_order: { tuple (range (meta .dim ()))} , but got: { meta .dim_order ()} for node { node .name } "
6372 )
64- inputs = [TosaArg (node )]
65- input_shape = inputs [0 ].shape
66- input_dim_order = inputs [0 ].dim_order
73+ try :
74+ tosa_arg = TosaArg (node )
75+ except ValueError as e :
76+ raise ValueError (
77+ f"Failed processing input placeholder:\n { get_node_debug_info (node )} "
78+ "Is the original torch function supported?"
79+ ) from e
80+ input_shape = tosa_arg .shape
81+ input_dim_order = tosa_arg .dim_order
6782 tensor = ts .TosaSerializerTensor (
68- inputs [ 0 ] .name ,
83+ tosa_arg .name ,
6984 tosa_shape (input_shape , input_dim_order ),
70- inputs [ 0 ] .dtype ,
85+ tosa_arg .dtype ,
7186 data = None ,
72- placeholderFilename = inputs [ 0 ] .name + ".npy" ,
87+ placeholderFilename = tosa_arg .name + ".npy" ,
7388 )
7489 tosa_graph .addInputTensor (tensor )
7590
@@ -81,20 +96,26 @@ def process_inputs_to_parameters(
8196 tosa_spec : TosaSpecification ,
8297):
8398 """Serialize bias and non-quantized weights"""
84- inputs = [TosaArg (node )]
85- parameter_name = edge_program .graph_signature .inputs_to_parameters [node .name ]
99+ try :
100+ tosa_arg = TosaArg (node )
101+ except ValueError as e :
102+ raise ValueError (
103+ f"Failed processing parameter placeholder:\n { get_node_debug_info (node )} "
104+ "Is the original torch function supported?"
105+ ) from e
106+ parameter_name = edge_program .graph_signature .inputs_to_parameters [tosa_arg .name ]
86107 parameter_data = edge_program .state_dict [parameter_name ]
87108
88109 assert isinstance (parameter_data , torch .Tensor ), "Expect Attr to be tensor"
89110 parameter_values = parameter_data .detach ().numpy ()
90111
91- if inputs [ 0 ] .dtype == torch .float32 :
112+ if tosa_arg .dtype == torch .float32 :
92113 assert tosa_spec .support_float (), f"{ tosa_spec } doesn't support float"
93114
94- parameter_values = np .transpose (parameter_values , inputs [ 0 ] .dim_order )
115+ parameter_values = np .transpose (parameter_values , tosa_arg .dim_order )
95116
96117 tosa_graph .addConst (
97- parameter_values .shape , inputs [ 0 ] .dtype , parameter_values , name = node .name
118+ parameter_values .shape , tosa_arg .dtype , parameter_values , name = tosa_arg .name
98119 )
99120
100121
@@ -104,7 +125,13 @@ def process_inputs_to_buffers(
104125 edge_program : ExportedProgram ,
105126):
106127 """Serialize quantized weights"""
107- inputs = [TosaArg (node )]
128+ try :
129+ tosa_arg = TosaArg (node )
130+ except ValueError as e :
131+ raise ValueError (
132+ f"Failed processing buffer placeholder:\n { get_node_debug_info (node )} "
133+ "Is the original torch function supported?"
134+ ) from e
108135 buffer_name = edge_program .graph_signature .inputs_to_buffers [node .name ]
109136 buffer_data = edge_program .state_dict [buffer_name ]
110137
@@ -114,10 +141,10 @@ def process_inputs_to_buffers(
114141 # TODO: fragile code for temporary fix
115142 # the mean and var tensors are also stored here but they have shape (1, )
116143 # we only transpose weights here
117- buffer_values = np .transpose (buffer_values , inputs [ 0 ] .dim_order )
144+ buffer_values = np .transpose (buffer_values , tosa_arg .dim_order )
118145
119146 tosa_graph .addConst (
120- buffer_values .shape , inputs [ 0 ] .dtype , buffer_values , name = node .name
147+ buffer_values .shape , tosa_arg .dtype , buffer_values , name = node .name
121148 )
122149
123150
@@ -126,14 +153,22 @@ def process_inputs_to_lifted_tensor_constants(
126153 tosa_graph : ts .TosaSerializer ,
127154 edge_program : ExportedProgram ,
128155):
129- arg = TosaArg (node )
156+ try :
157+ tosa_arg = TosaArg (node )
158+ except ValueError as e :
159+ raise ValueError (
160+ f"Failed processing lifted tensor constant placeholder:\n { get_node_debug_info (node )} "
161+ "Is the original torch function supported?"
162+ ) from e
130163 tensor_name = edge_program .graph_signature .inputs_to_lifted_tensor_constants [
131- arg .name
164+ tosa_arg .name
132165 ]
133166 tensor = edge_program .tensor_constants [tensor_name ]
134167 tensor_data = tensor .detach ().numpy ()
135168
136- tosa_graph .addConst (tensor_data .shape , arg .dtype , tensor_data , name = arg .name )
169+ tosa_graph .addConst (
170+ tensor_data .shape , tosa_arg .dtype , tensor_data , name = tosa_arg .name
171+ )
137172
138173
139174def process_placeholder (
0 commit comments