From 530e6d1123f8ee0f446cf2c2a556d85984eb9732 Mon Sep 17 00:00:00 2001 From: Per Held Date: Wed, 12 Mar 2025 10:19:53 +0100 Subject: [PATCH] Arm backend: Enable mypy lintrunner for backends/arm/test Enable the lintrunner to run mypy on backends/arm/test that was ignored in the initial commit of enabling mypy for backends/arm. Mostly sad ignores but also some actual fixes. Change-Id: Iba12558511ae938864086eb0539063a87177c0ae --- .lintrunner.toml | 1 - backends/arm/test/common.py | 16 +++++++--- backends/arm/test/conftest.py | 2 +- backends/arm/test/runner_utils.py | 52 ++++++++++++++++++++----------- backends/arm/test/test_model.py | 2 +- 5 files changed, 46 insertions(+), 27 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index ed782b12383..b00aa1e683b 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -320,7 +320,6 @@ exclude_patterns = [ '**/third-party/**', 'scripts/check_binary_dependencies.py', 'profiler/test/test_profiler_e2e.py', - 'backends/arm/test/**', ] command = [ 'python', diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index b2114bb0e99..8e3d1f2745b 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -13,14 +13,20 @@ from typing import Any import pytest -from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder -from executorch.backends.arm.test.runner_utils import ( +from executorch.backends.arm.arm_backend import ( # type: ignore[import-not-found] + ArmCompileSpecBuilder, +) +from executorch.backends.arm.test.runner_utils import ( # type: ignore[import-not-found] arm_executor_runner_exists, corstone300_installed, corstone320_installed, ) -from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.backends.arm.tosa_specification import ( # type: ignore[import-not-found] + TosaSpecification, +) +from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] + CompileSpec, +) def get_time_formatted_path(path: str, log_prefix: str) -> str: @@ -193,7 +199,7 @@ def get_u85_compile_spec_unbuilt( def parametrize( - arg_name: str, test_data: dict[str, Any], xfails: dict[str, str] = None + arg_name: str, test_data: dict[str, Any], xfails: dict[str, str] | None = None ): """ Custom version of pytest.mark.parametrize with some syntatic sugar and added xfail functionality diff --git a/backends/arm/test/conftest.py b/backends/arm/test/conftest.py index 66b4b1fc999..ca11733b025 100644 --- a/backends/arm/test/conftest.py +++ b/backends/arm/test/conftest.py @@ -15,7 +15,7 @@ import pytest try: - import tosa_reference_model + import tosa_reference_model # type: ignore[import-untyped] except ImportError: logging.warning("tosa_reference_model not found, can't run reference model tests") tosa_reference_model = None diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 7dc25a1a4e7..ccbc2eacc1b 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -20,22 +20,36 @@ logger = logging.getLogger(__name__) try: - import tosa_reference_model + import tosa_reference_model # type: ignore[import-untyped] except ImportError: tosa_reference_model = None -from executorch.backends.arm.arm_backend import get_tosa_spec, is_tosa - -from executorch.backends.arm.test.conftest import is_option_enabled -from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.exir import ExecutorchProgramManager, ExportedProgram -from executorch.exir.backend.compile_spec_schema import CompileSpec -from executorch.exir.lowered_backend_module import LoweredBackendModule +from executorch.backends.arm.arm_backend import ( # type: ignore[import-not-found] + get_tosa_spec, + is_tosa, +) + +from executorch.backends.arm.test.conftest import ( # type: ignore[import-not-found] + is_option_enabled, +) +from executorch.backends.arm.tosa_specification import ( # type: ignore[import-not-found] + TosaSpecification, +) +from executorch.exir import ( # type: ignore[import-not-found] + ExecutorchProgramManager, + ExportedProgram, +) +from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] + CompileSpec, +) +from executorch.exir.lowered_backend_module import ( # type: ignore[import-not-found] + LoweredBackendModule, +) from packaging.version import Version from torch.fx.node import Node from torch.overrides import TorchFunctionMode from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict -from tosa import TosaGraph +from tosa import TosaGraph # type: ignore[import-untyped] logger = logging.getLogger(__name__) logger.setLevel(logging.CRITICAL) @@ -148,16 +162,16 @@ def get_output_quantization_params( for node in output_nodes: if node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default: quant_params[node] = QuantizationParams( - node_name=node.args[0].name, - scale=node.args[1], - zp=node.args[2], - qmin=node.args[3], - qmax=node.args[4], - dtype=node.args[5], + node_name=node.args[0].name, # type: ignore + scale=node.args[1], # type: ignore + zp=node.args[2], # type: ignore + qmin=node.args[3], # type: ignore + qmax=node.args[4], # type: ignore + dtype=node.args[5], # type: ignore ) else: - quant_params[node] = None - return quant_params + quant_params[node] = None # type: ignore[assignment] + return quant_params # type: ignore[return-value] class TosaReferenceModelDispatch(TorchFunctionMode): @@ -243,7 +257,7 @@ def run_corstone( input_names = get_input_names(exported_program) input_paths = [] for input_name, input_ in zip(input_names, inputs): - input_path = save_bytes(intermediate_path, input_, input_name) + input_path = save_bytes(intermediate_path.as_posix(), input_, input_name) input_paths.append(input_path) out_path = os.path.join(intermediate_path, "out") @@ -348,7 +362,7 @@ def prep_data_for_save( quant_param: Optional[QuantizationParams] = None, ): if isinstance(data, torch.Tensor): - data_np = np.array(data.detach(), order="C").astype( + data_np = np.array(data.detach(), order="C").astype( # type: ignore[var-annotated] torch_to_numpy_dtype_dict[data.dtype] ) else: diff --git a/backends/arm/test/test_model.py b/backends/arm/test/test_model.py index b94a5f65256..52d73e9dbce 100755 --- a/backends/arm/test/test_model.py +++ b/backends/arm/test/test_model.py @@ -83,7 +83,7 @@ def get_args(): return args -def run_external_cmd(cmd: []): +def run_external_cmd(cmd: list[str]): print("CALL:", *cmd, sep=" ") try: subprocess.check_call(cmd)