@@ -115,50 +115,53 @@ def _get_input_quantization_params(
115115 return quant_params
116116
117117
118- def _get_output_node (program : ExportedProgram ) -> Node :
118+ def _get_output_nodes (program : ExportedProgram ) -> list [ Node ] :
119119 """
120120 Get output node to this model.
121121
122122 Args:
123- program (ExportedProgram): The program to get output node from.
123+ program (ExportedProgram): The program to get the output nodes from.
124124 Returns:
125- The node that is the output of 'program'.
125+ The nodes that are the outputs of the 'program'.
126126 """
127-
127+ output_nodes = []
128128 for node in program .graph .nodes :
129129 if node .op == "output" :
130- return node
131- raise RuntimeError ("No output node found." )
130+ for output in node .args [0 ]:
131+ output_nodes .append (output )
132+ if len (output_nodes ) == 0 :
133+ raise RuntimeError ("No output nodes found." )
134+ else :
135+ return output_nodes
132136
133137
134138def _get_output_quantization_params (
135- program : ExportedProgram , output_node : Node
136- ) -> Optional [QuantizationParams ]:
139+ output_nodes : list [ Node ],
140+ ) -> List [QuantizationParams ]:
137141 """
138142 Get output QuantizationParams from a program.
139143 Args:
140- program (ExportedProgram): The program to get output quantization parameters from.
144+ output_nodes (list(Node)): A list of output nodes to get output quantization parameters from.
141145 Returns:
142146 QuantizationParams: The found quantization parameters.
143147 Raises:
144148 RuntimeError if no output quantization parameters are found.
145149 """
146-
147- quant_params = None
148- for node in program .graph .nodes :
149- if (
150- node .target == torch .ops .quantized_decomposed .dequantize_per_tensor .default
151- and node == output_node .args [0 ][0 ]
152- ):
153- quant_params = QuantizationParams (
154- node_name = node .args [0 ].name ,
155- scale = node .args [1 ],
156- zp = node .args [2 ],
157- qmin = node .args [3 ],
158- qmax = node .args [4 ],
159- dtype = node .args [5 ],
150+ quant_params = []
151+ for node in output_nodes :
152+ if node .target == torch .ops .quantized_decomposed .dequantize_per_tensor .default :
153+ quant_params .append (
154+ QuantizationParams (
155+ node_name = node .args [0 ].name ,
156+ scale = node .args [1 ],
157+ zp = node .args [2 ],
158+ qmin = node .args [3 ],
159+ qmax = node .args [4 ],
160+ dtype = node .args [5 ],
161+ )
160162 )
161- break # break early, there's only one output node
163+ if len (quant_params ) == 0 :
164+ raise RuntimeError ("No Quantization parameters not found in exported model." )
162165 return quant_params
163166
164167
@@ -211,7 +214,7 @@ def __init__(
211214 self .input_names : list [str ] = None
212215 self .output_name : str = None
213216 self .qp_input : list [QuantizationParams ] = None
214- self .qp_output : QuantizationParams = None
217+ self .qp_output : list [ QuantizationParams ] = None
215218 self .timeout = 480
216219 self .target_board : str = None
217220
@@ -226,19 +229,17 @@ def init_run(
226229 ):
227230
228231 self .input_names = _get_input_names (edge_program )
229- self .output_node = _get_output_node (exported_program )
230- self . output_name = self . output_node . name
232+ self .output_nodes = _get_output_nodes (exported_program )
233+
231234 self .is_quantized = is_quantized
232235 self .target_board = target_board
233236
234237 if is_quantized :
235238 self .qp_input = _get_input_quantization_params (exported_program )
236- self .qp_output = _get_output_quantization_params (
237- exported_program , self .output_node
238- )
239+ self .qp_output = _get_output_quantization_params (self .output_nodes )
239240 else :
240241 self .qp_input = [None ] * len (self .input_names )
241- self .qp_output = None
242+ self .qp_output = [ None ] * len ( self . output_nodes )
242243
243244 self ._has_init_run = True
244245
@@ -265,7 +266,7 @@ def run_corstone(
265266 save_bytes (self .intermediate_path , data , False , input_name , quant_param )
266267
267268 out_path = os .path .join (self .intermediate_path , "out" )
268- out_path_with_suffix = out_path + "-0.bin"
269+
269270 input_paths = []
270271 for name in self .input_names :
271272 input_paths .append (
@@ -281,6 +282,7 @@ def run_corstone(
281282 ), f"Did not find build arm_executor_runner in path { elf_path } , run setup_testing.sh?"
282283
283284 cmd_line = f"executor_runner -m { pte_path } -o { out_path } "
285+
284286 for input_path in input_paths :
285287 cmd_line += f" -i { input_path } "
286288
@@ -362,11 +364,14 @@ def run_corstone(
362364 raise RuntimeError (
363365 f"Corstone simulation failed:\n cmd: { command_args [self .target_board ]} \n , log: \n { result_stdout } \n { result .stderr .decode ()} "
364366 )
365-
366- tosa_ref_output = np .fromfile (out_path_with_suffix , dtype = np .float32 )
367- output_shape = self .output_node .args [0 ][0 ].meta ["val" ].shape
368- tosa_ref_output = torch .from_numpy (tosa_ref_output ).reshape (output_shape )
369- return tosa_ref_output
367+ output_np = []
368+ for i , node in enumerate (self .output_nodes ):
369+ tosa_ref_output = np .fromfile (
370+ os .path .join (self .intermediate_path , f"out-{ i } .bin" ), dtype = np .float32
371+ )
372+ output_shape = node .meta ["val" ].shape
373+ output_np .append (torch .from_numpy (tosa_ref_output ).reshape (output_shape ))
374+ return tuple (output_np )
370375
371376 def run_tosa_graph (
372377 self , graph : TosaGraph , inputs : list [np .ndarray ] | list [torch .Tensor ]
0 commit comments