1414from pathlib import Path
1515
1616from types import NoneType
17- from typing import Any , cast , Dict , List , Literal , Optional , Tuple
17+ from typing import Any , cast , Dict , List , Optional , Tuple
1818
1919import numpy as np
2020import torch
3737from torch .fx .node import Node
3838
3939from torch .overrides import TorchFunctionMode
40- from tosa .TosaGraph import TosaGraph
40+ from tosa .TosaGraph import TosaGraph # type: ignore[import-untyped]
4141
4242logger = logging .getLogger (__name__ )
4343
@@ -149,25 +149,28 @@ def get_output_quantization_params(
149149 Raises:
150150 RuntimeError if no output quantization parameters are found.
151151 """
152- quant_params = {}
153- for node in output_node .args [0 ]:
154- if node .target == torch .ops .quantized_decomposed .dequantize_per_tensor .default :
155- quant_params [node ] = QuantizationParams (
156- node_name = node .args [0 ].name ,
157- scale = node .args [1 ],
158- zp = node .args [2 ],
159- qmin = node .args [3 ],
160- qmax = node .args [4 ],
161- dtype = node .args [5 ],
152+ quant_params : dict [Node , QuantizationParams | None ] = {}
153+ for node in output_node .args [0 ]: # type: ignore[union-attr]
154+ if (
155+ node .target # type: ignore[union-attr]
156+ == torch .ops .quantized_decomposed .dequantize_per_tensor .default
157+ ):
158+ quant_params [node ] = QuantizationParams ( # type: ignore[index]
159+ node_name = node .args [0 ].name , # type: ignore[arg-type, union-attr]
160+ scale = node .args [1 ], # type: ignore[arg-type, union-attr]
161+ zp = node .args [2 ], # type: ignore[arg-type, union-attr]
162+ qmin = node .args [3 ], # type: ignore[arg-type, union-attr]
163+ qmax = node .args [4 ], # type: ignore[arg-type, union-attr]
164+ dtype = node .args [5 ], # type: ignore[arg-type, union-attr]
162165 )
163166 else :
164- quant_params [node ] = None
167+ quant_params [node ] = None # type: ignore[index]
165168 return quant_params
166169
167170
168171def torch_tensor_to_numpy (tensor : torch .Tensor ) -> np .ndarray :
169172 dtype = _torch_to_numpy_dtype_dict [tensor .dtype ]
170- array = tensor .detach ().numpy ().astype (dtype )
173+ array = tensor .detach ().numpy ().astype (dtype ) # type: ignore[var-annotated]
171174 dim_order = tensor .dim_order ()
172175 if dim_order == NHWC_ORDER :
173176 a = array .transpose (NHWC_ORDER )
@@ -252,40 +255,39 @@ def run_target(
252255 executorch_program_manager : ExecutorchProgramManager ,
253256 inputs : Tuple [torch .Tensor ],
254257 intermediate_path : str | Path ,
255- target_board : Literal [ "corestone-300" , "corestone-320" , "vkml_emulation_layer" ] ,
258+ target_board : str ,
256259 elf_path : str | Path ,
257260 timeout : int = 120 , # s
258261):
259262 if target_board not in VALID_TARGET :
260263 raise ValueError (f"Unsupported target: { target_board } " )
261264
262- if target_board in ("corstone-300" , "corstone-320" ):
263- return run_corstone (
264- executorch_program_manager ,
265- inputs ,
266- intermediate_path ,
267- target_board ,
268- elf_path ,
269- timeout ,
270- )
271- elif target_board == "vkml_emulation_layer" :
265+ if target_board == "vkml_emulation_layer" :
272266 return run_vkml_emulation_layer (
273267 executorch_program_manager ,
274268 inputs ,
275269 intermediate_path ,
276270 elf_path ,
277271 )
272+ return run_corstone (
273+ executorch_program_manager ,
274+ inputs ,
275+ intermediate_path ,
276+ target_board ,
277+ elf_path ,
278+ timeout ,
279+ )
278280
279281
280282def save_inputs_to_file (
281283 exported_program : ExportedProgram ,
282284 inputs : Tuple [torch .Tensor ],
283285 intermediate_path : str | Path ,
284286):
285- input_file_paths = []
287+ input_file_paths : list [ str ] = []
286288 input_names = get_input_names (exported_program )
287289 for input_name , input_ in zip (input_names , inputs ):
288- input_path = save_bytes (intermediate_path , input_ , input_name )
290+ input_path = save_bytes (intermediate_path , input_ , input_name ) # type: ignore[arg-type]
289291 input_file_paths .append (input_path )
290292
291293 return input_file_paths
@@ -298,9 +300,9 @@ def get_output_from_file(
298300):
299301 output_np = []
300302 output_node = exported_program .graph_module .graph .output_node ()
301- for i , node in enumerate (output_node .args [0 ]):
303+ for i , node in enumerate (output_node .args [0 ]): # type: ignore[union-attr]
302304 output_dtype = node .meta ["val" ].dtype
303- tosa_ref_output = np .fromfile (
305+ tosa_ref_output = np .fromfile ( # type: ignore[var-annotated]
304306 os .path .join (intermediate_path , f"{ output_base_name } -{ i } .bin" ),
305307 _torch_to_numpy_dtype_dict [output_dtype ],
306308 )
@@ -362,7 +364,7 @@ def run_corstone(
362364 executorch_program_manager : ExecutorchProgramManager ,
363365 inputs : Tuple [torch .Tensor ],
364366 intermediate_path : str | Path ,
365- target_board : Literal [ "corestone-300" , "corestone-320" ] ,
367+ target_board : str ,
366368 elf_path : str | Path ,
367369 timeout : int = 120 , # s
368370) -> list [torch .Tensor ]:
@@ -759,7 +761,7 @@ def run_tosa_graph(
759761 inputs_np = [torch_tensor_to_numpy (input_tensor ) for input_tensor in inputs ]
760762
761763 if isinstance (tosa_version , Tosa_1_00 ):
762- import tosa_reference_model as reference_model
764+ import tosa_reference_model as reference_model # type: ignore[import-untyped]
763765
764766 debug_mode = "ALL" if logger .level <= logging .DEBUG else None
765767 outputs_np , status = reference_model .run (
@@ -781,7 +783,7 @@ def run_tosa_graph(
781783 # Convert output numpy arrays to tensors with same dim_order as the output nodes
782784 result = [
783785 numpy_to_torch_tensor (output_array , node )
784- for output_array , node in zip (outputs_np , output_node .args [0 ])
786+ for output_array , node in zip (outputs_np , output_node .args [0 ]) # type: ignore[arg-type]
785787 ]
786788
787789 return result
0 commit comments