1414from  pathlib  import  Path 
1515
1616from  types  import  NoneType 
17- from  typing  import  Any , cast , Dict , List , Literal ,  Optional , Tuple 
17+ from  typing  import  Any , cast , Dict , List , Optional , Tuple 
1818
1919import  numpy  as  np 
2020import  torch 
3737from  torch .fx .node  import  Node 
3838
3939from  torch .overrides  import  TorchFunctionMode 
40- from  tosa .TosaGraph  import  TosaGraph 
40+ from  tosa .TosaGraph  import  TosaGraph    # type: ignore[import-untyped] 
4141
4242logger  =  logging .getLogger (__name__ )
4343
@@ -149,25 +149,28 @@ def get_output_quantization_params(
149149    Raises: 
150150        RuntimeError if no output quantization parameters are found. 
151151    """ 
152-     quant_params  =  {}
153-     for  node  in  output_node .args [0 ]:
154-         if  node .target  ==  torch .ops .quantized_decomposed .dequantize_per_tensor .default :
155-             quant_params [node ] =  QuantizationParams (
156-                 node_name = node .args [0 ].name ,
157-                 scale = node .args [1 ],
158-                 zp = node .args [2 ],
159-                 qmin = node .args [3 ],
160-                 qmax = node .args [4 ],
161-                 dtype = node .args [5 ],
152+     quant_params : dict [Node , QuantizationParams  |  None ] =  {}
153+     for  node  in  output_node .args [0 ]:  # type: ignore[union-attr] 
154+         if  (
155+             node .target   # type: ignore[union-attr] 
156+             ==  torch .ops .quantized_decomposed .dequantize_per_tensor .default 
157+         ):
158+             quant_params [node ] =  QuantizationParams (  # type: ignore[index] 
159+                 node_name = node .args [0 ].name ,  # type: ignore[arg-type, union-attr] 
160+                 scale = node .args [1 ],  # type: ignore[arg-type, union-attr] 
161+                 zp = node .args [2 ],  # type: ignore[arg-type, union-attr] 
162+                 qmin = node .args [3 ],  # type: ignore[arg-type, union-attr] 
163+                 qmax = node .args [4 ],  # type: ignore[arg-type, union-attr] 
164+                 dtype = node .args [5 ],  # type: ignore[arg-type, union-attr] 
162165            )
163166        else :
164-             quant_params [node ] =  None 
167+             quant_params [node ] =  None    # type: ignore[index] 
165168    return  quant_params 
166169
167170
168171def  torch_tensor_to_numpy (tensor : torch .Tensor ) ->  np .ndarray :
169172    dtype  =  _torch_to_numpy_dtype_dict [tensor .dtype ]
170-     array  =  tensor .detach ().numpy ().astype (dtype )
173+     array  =  tensor .detach ().numpy ().astype (dtype )   # type: ignore[var-annotated] 
171174    dim_order  =  tensor .dim_order ()
172175    if  dim_order  ==  NHWC_ORDER :
173176        a  =  array .transpose (NHWC_ORDER )
def run_target(
    executorch_program_manager: ExecutorchProgramManager,
    inputs: Tuple[torch.Tensor],
    intermediate_path: str | Path,
    target_board: str,
    elf_path: str | Path,
    timeout: int = 120,  # s
):
    """Run the program on the requested target and return its outputs.

    Validates ``target_board`` against ``VALID_TARGET``, then dispatches to
    the matching runner: the VKML emulation layer, or a Corstone FVP for
    every other valid target.

    Raises:
        ValueError: if ``target_board`` is not listed in ``VALID_TARGET``.
    """
    if target_board not in VALID_TARGET:
        raise ValueError(f"Unsupported target: {target_board}")

    use_emulation_layer = target_board == "vkml_emulation_layer"
    if use_emulation_layer:
        # The emulation-layer runner takes neither a board name nor a timeout.
        return run_vkml_emulation_layer(
            executorch_program_manager,
            inputs,
            intermediate_path,
            elf_path,
        )

    # Remaining valid targets are Corstone FVP boards.
    return run_corstone(
        executorch_program_manager,
        inputs,
        intermediate_path,
        target_board,
        elf_path,
        timeout,
    )
278280
279281
def save_inputs_to_file(
    exported_program: ExportedProgram,
    inputs: Tuple[torch.Tensor],
    intermediate_path: str | Path,
):
    """Write each input tensor to a file under ``intermediate_path``.

    Input names are taken from ``exported_program`` and matched positionally
    with ``inputs``; each pair is handed to ``save_bytes`` (presumably writing
    the raw tensor bytes — confirm against its definition).

    Args:
        exported_program: Program whose graph supplies the input names.
        inputs: Tensors to write, one per input name.
        intermediate_path: Directory in which the files are created.

    Returns:
        List of paths of the files written, one per input.

    Raises:
        ValueError: if the number of inputs does not match the number of
            input names (strict zip).
    """
    input_file_paths: list[str] = []
    input_names = get_input_names(exported_program)
    # strict=True surfaces a name/tensor count mismatch instead of silently
    # truncating to the shorter sequence and dropping inputs.
    for input_name, input_ in zip(input_names, inputs, strict=True):
        input_path = save_bytes(intermediate_path, input_, input_name)  # type: ignore[arg-type]
        input_file_paths.append(input_path)

    return input_file_paths
@@ -298,9 +300,9 @@ def get_output_from_file(
298300):
299301    output_np  =  []
300302    output_node  =  exported_program .graph_module .graph .output_node ()
301-     for  i , node  in  enumerate (output_node .args [0 ]):
303+     for  i , node  in  enumerate (output_node .args [0 ]):   # type: ignore[union-attr] 
302304        output_dtype  =  node .meta ["val" ].dtype 
303-         tosa_ref_output  =  np .fromfile (
305+         tosa_ref_output  =  np .fromfile (   # type: ignore[var-annotated] 
304306            os .path .join (intermediate_path , f"{ output_base_name }  -{ i }  .bin" ),
305307            _torch_to_numpy_dtype_dict [output_dtype ],
306308        )
@@ -362,7 +364,7 @@ def run_corstone(
362364    executorch_program_manager : ExecutorchProgramManager ,
363365    inputs : Tuple [torch .Tensor ],
364366    intermediate_path : str  |  Path ,
365-     target_board : Literal [ "corestone-300" ,  "corestone-320" ] ,
367+     target_board : str ,
366368    elf_path : str  |  Path ,
367369    timeout : int  =  120 ,  # s 
368370) ->  list [torch .Tensor ]:
@@ -749,7 +751,7 @@ def run_tosa_graph(
749751    inputs_np  =  [torch_tensor_to_numpy (input_tensor ) for  input_tensor  in  inputs ]
750752
751753    if  isinstance (tosa_version , Tosa_1_00 ):
752-         import  tosa_reference_model  as  reference_model 
754+         import  tosa_reference_model  as  reference_model    # type: ignore[import-untyped] 
753755
754756        debug_mode  =  "ALL"  if  logger .level  <=  logging .DEBUG  else  None 
755757        outputs_np , status  =  reference_model .run (
@@ -771,7 +773,7 @@ def run_tosa_graph(
771773    # Convert output numpy arrays to tensors with same dim_order as the output nodes 
772774    result  =  [
773775        numpy_to_torch_tensor (output_array , node )
774-         for  output_array , node  in  zip (outputs_np , output_node .args [0 ])
776+         for  output_array , node  in  zip (outputs_np , output_node .args [0 ])   # type: ignore[arg-type] 
775777    ]
776778
777779    return  result 
0 commit comments