From fdf2dc1dd0bcee4971aa69090a4d4476d16be157 Mon Sep 17 00:00:00 2001 From: Per Held Date: Tue, 7 Oct 2025 12:40:35 +0200 Subject: [PATCH] Arm backend: Fix mypy warnings in test root dir Signed-off-by: per.held@arm.com Change-Id: Ic27ab7b212765d597671b128b0cb3075ab64a548 --- backends/arm/test/common.py | 18 ++++++--- backends/arm/test/runner_utils.py | 66 ++++++++++++++++--------------- backends/arm/test/test_model.py | 3 +- 3 files changed, 48 insertions(+), 39 deletions(-) diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 3b5dd8bd4db..d8c7ae1a570 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -10,7 +10,7 @@ from datetime import datetime from pathlib import Path -from typing import Any, Optional +from typing import Any, Callable, Optional, ParamSpec, TypeVar import pytest from executorch.backends.arm.ethosu import EthosUCompileSpec @@ -205,7 +205,7 @@ def get_vgf_compile_spec( ) """Xfails a test if Corsone320 FVP is not installed, or if the executor runner is not built""" -SkipIfNoModelConverter = pytest.mark.skipif( +SkipIfNoModelConverter = pytest.mark.skipif( # type: ignore[call-arg] condition=not (model_converter_installed()), raises=FileNotFoundError, reason="Did not find model-converter on path", @@ -221,6 +221,10 @@ def get_vgf_compile_spec( xfail_type = str | tuple[str, type[Exception]] +_P = ParamSpec("_P") +_R = TypeVar("_R") +Decorator = Callable[[Callable[_P, _R]], Callable[_P, _R]] + def parametrize( arg_name: str, @@ -228,7 +232,7 @@ def parametrize( xfails: dict[str, xfail_type] | None = None, strict: bool = True, flakies: dict[str, int] | None = None, -): +) -> Decorator: """ Custom version of pytest.mark.parametrize with some syntatic sugar and added xfail functionality - test_data is expected as a dict of (id, test_data) pairs @@ -241,7 +245,7 @@ def parametrize( if flakies is None: flakies = {} - def decorator_func(func): + def decorator_func(func: Callable[_P, _R]) -> Callable[_P, _R]: """Test data is transformed from a dict of (id, data) pairs to a list of pytest params to work with the native pytests parametrize function""" pytest_testsuite = [] for id, test_parameters in test_data.items(): @@ -261,14 +265,16 @@ def decorator_func(func): "xfail info needs to be str, or tuple[str, type[Exception]]" ) # Set up our fail marker + marker: tuple[pytest.MarkDecorator, ...] # type: ignore[no-redef] marker = ( pytest.mark.xfail(reason=reason, raises=raises, strict=strict), ) else: - marker = () + marker = () # type: ignore[assignment] pytest_param = pytest.param(test_parameters, id=id, marks=marker) pytest_testsuite.append(pytest_param) - return pytest.mark.parametrize(arg_name, pytest_testsuite)(func) + decorator = pytest.mark.parametrize(arg_name, pytest_testsuite) + return decorator(func) return decorator_func diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 69d9f838034..9ac488f33e3 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -14,7 +14,7 @@ from pathlib import Path from types import NoneType -from typing import Any, cast, Dict, List, Literal, Optional, Tuple +from typing import Any, cast, Dict, List, Optional, Tuple import numpy as np import torch @@ -37,7 +37,7 @@ from torch.fx.node import Node from torch.overrides import TorchFunctionMode -from tosa.TosaGraph import TosaGraph +from tosa.TosaGraph import TosaGraph # type: ignore[import-untyped] logger = logging.getLogger(__name__) @@ -149,25 +149,28 @@ def get_output_quantization_params( Raises: RuntimeError if no output quantization parameters are found. """ - quant_params = {} - for node in output_node.args[0]: - if node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default: - quant_params[node] = QuantizationParams( - node_name=node.args[0].name, - scale=node.args[1], - zp=node.args[2], - qmin=node.args[3], - qmax=node.args[4], - dtype=node.args[5], + quant_params: dict[Node, QuantizationParams | None] = {} + for node in output_node.args[0]: # type: ignore[union-attr] + if ( + node.target # type: ignore[union-attr] + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): + quant_params[node] = QuantizationParams( # type: ignore[index] + node_name=node.args[0].name, # type: ignore[arg-type, union-attr] + scale=node.args[1], # type: ignore[arg-type, union-attr] + zp=node.args[2], # type: ignore[arg-type, union-attr] + qmin=node.args[3], # type: ignore[arg-type, union-attr] + qmax=node.args[4], # type: ignore[arg-type, union-attr] + dtype=node.args[5], # type: ignore[arg-type, union-attr] ) else: - quant_params[node] = None + quant_params[node] = None # type: ignore[index] return quant_params def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: dtype = _torch_to_numpy_dtype_dict[tensor.dtype] - array = tensor.detach().numpy().astype(dtype) + array = tensor.detach().numpy().astype(dtype) # type: ignore[var-annotated] dim_order = tensor.dim_order() if dim_order == NHWC_ORDER: a = array.transpose(NHWC_ORDER) @@ -252,29 +255,28 @@ def run_target( executorch_program_manager: ExecutorchProgramManager, inputs: Tuple[torch.Tensor], intermediate_path: str | Path, - target_board: Literal["corestone-300", "corestone-320", "vkml_emulation_layer"], + target_board: str, elf_path: str | Path, timeout: int = 120, # s ): if target_board not in VALID_TARGET: raise ValueError(f"Unsupported target: {target_board}") - if target_board in ("corstone-300", "corstone-320"): - return run_corstone( - executorch_program_manager, - inputs, - intermediate_path, - target_board, - elf_path, - timeout, - ) - elif target_board == "vkml_emulation_layer": + if target_board == "vkml_emulation_layer": return run_vkml_emulation_layer( executorch_program_manager, inputs, intermediate_path, elf_path, ) + return run_corstone( + executorch_program_manager, + inputs, + intermediate_path, + target_board, + elf_path, + timeout, + ) def save_inputs_to_file( @@ -282,10 +284,10 @@ def save_inputs_to_file( inputs: Tuple[torch.Tensor], intermediate_path: str | Path, ): - input_file_paths = [] + input_file_paths: list[str] = [] input_names = get_input_names(exported_program) for input_name, input_ in zip(input_names, inputs): - input_path = save_bytes(intermediate_path, input_, input_name) + input_path = save_bytes(intermediate_path, input_, input_name) # type: ignore[arg-type] input_file_paths.append(input_path) return input_file_paths @@ -298,9 +300,9 @@ def get_output_from_file( ): output_np = [] output_node = exported_program.graph_module.graph.output_node() - for i, node in enumerate(output_node.args[0]): + for i, node in enumerate(output_node.args[0]): # type: ignore[union-attr] output_dtype = node.meta["val"].dtype - tosa_ref_output = np.fromfile( + tosa_ref_output = np.fromfile( # type: ignore[var-annotated] os.path.join(intermediate_path, f"{output_base_name}-{i}.bin"), _torch_to_numpy_dtype_dict[output_dtype], ) @@ -362,7 +364,7 @@ def run_corstone( executorch_program_manager: ExecutorchProgramManager, inputs: Tuple[torch.Tensor], intermediate_path: str | Path, - target_board: Literal["corestone-300", "corestone-320"], + target_board: str, elf_path: str | Path, timeout: int = 120, # s ) -> list[torch.Tensor]: @@ -749,7 +751,7 @@ def run_tosa_graph( inputs_np = [torch_tensor_to_numpy(input_tensor) for input_tensor in inputs] if isinstance(tosa_version, Tosa_1_00): - import tosa_reference_model as reference_model + import tosa_reference_model as reference_model # type: ignore[import-untyped] debug_mode = "ALL" if logger.level <= logging.DEBUG else None outputs_np, status = reference_model.run( @@ -771,7 +773,7 @@ def run_tosa_graph( # Convert output numpy arrays to tensors with same dim_order as the output nodes result = [ numpy_to_torch_tensor(output_array, node) - for output_array, node in zip(outputs_np, output_node.args[0]) + for output_array, node in zip(outputs_np, output_node.args[0]) # type: ignore[arg-type] ] return result diff --git a/backends/arm/test/test_model.py b/backends/arm/test/test_model.py index c336d67ad51..5dc11e12a08 100755 --- a/backends/arm/test/test_model.py +++ b/backends/arm/test/test_model.py @@ -8,6 +8,7 @@ import subprocess import sys import time +from typing import Sequence def get_args(): @@ -96,7 +97,7 @@ def get_args(): return args -def run_external_cmd(cmd: []): +def run_external_cmd(cmd: Sequence[str]) -> None: print("CALL:", *cmd, sep=" ") try: subprocess.check_call(cmd)