Skip to content

Commit ed91b6a

Browse files
authored
Arm backend: Fix mypy warnings in test root dir (pytorch#15365)
Signed-off-by: [email protected]
1 parent 6abe901 commit ed91b6a

File tree

3 files changed

+48
-39
lines changed

3 files changed

+48
-39
lines changed

backends/arm/test/common.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from datetime import datetime
1111

1212
from pathlib import Path
13-
from typing import Any, Optional
13+
from typing import Any, Callable, Optional, ParamSpec, TypeVar
1414

1515
import pytest
1616
from executorch.backends.arm.ethosu import EthosUCompileSpec
@@ -205,7 +205,7 @@ def get_vgf_compile_spec(
205205
)
206206
"""Xfails a test if Corsone320 FVP is not installed, or if the executor runner is not built"""
207207

208-
SkipIfNoModelConverter = pytest.mark.skipif(
208+
SkipIfNoModelConverter = pytest.mark.skipif( # type: ignore[call-arg]
209209
condition=not (model_converter_installed()),
210210
raises=FileNotFoundError,
211211
reason="Did not find model-converter on path",
@@ -221,14 +221,18 @@ def get_vgf_compile_spec(
221221

222222
xfail_type = str | tuple[str, type[Exception]]
223223

224+
_P = ParamSpec("_P")
225+
_R = TypeVar("_R")
226+
Decorator = Callable[[Callable[_P, _R]], Callable[_P, _R]]
227+
224228

225229
def parametrize(
226230
arg_name: str,
227231
test_data: dict[str, Any],
228232
xfails: dict[str, xfail_type] | None = None,
229233
strict: bool = True,
230234
flakies: dict[str, int] | None = None,
231-
):
235+
) -> Decorator:
232236
"""
233237
Custom version of pytest.mark.parametrize with some syntatic sugar and added xfail functionality
234238
- test_data is expected as a dict of (id, test_data) pairs
@@ -241,7 +245,7 @@ def parametrize(
241245
if flakies is None:
242246
flakies = {}
243247

244-
def decorator_func(func):
248+
def decorator_func(func: Callable[_P, _R]) -> Callable[_P, _R]:
245249
"""Test data is transformed from a dict of (id, data) pairs to a list of pytest params to work with the native pytests parametrize function"""
246250
pytest_testsuite = []
247251
for id, test_parameters in test_data.items():
@@ -261,14 +265,16 @@ def decorator_func(func):
261265
"xfail info needs to be str, or tuple[str, type[Exception]]"
262266
)
263267
# Set up our fail marker
268+
marker: tuple[pytest.MarkDecorator, ...] # type: ignore[no-redef]
264269
marker = (
265270
pytest.mark.xfail(reason=reason, raises=raises, strict=strict),
266271
)
267272
else:
268-
marker = ()
273+
marker = () # type: ignore[assignment]
269274

270275
pytest_param = pytest.param(test_parameters, id=id, marks=marker)
271276
pytest_testsuite.append(pytest_param)
272-
return pytest.mark.parametrize(arg_name, pytest_testsuite)(func)
277+
decorator = pytest.mark.parametrize(arg_name, pytest_testsuite)
278+
return decorator(func)
273279

274280
return decorator_func

backends/arm/test/runner_utils.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pathlib import Path
1515

1616
from types import NoneType
17-
from typing import Any, cast, Dict, List, Literal, Optional, Tuple
17+
from typing import Any, cast, Dict, List, Optional, Tuple
1818

1919
import numpy as np
2020
import torch
@@ -37,7 +37,7 @@
3737
from torch.fx.node import Node
3838

3939
from torch.overrides import TorchFunctionMode
40-
from tosa.TosaGraph import TosaGraph
40+
from tosa.TosaGraph import TosaGraph # type: ignore[import-untyped]
4141

4242
logger = logging.getLogger(__name__)
4343

@@ -149,25 +149,28 @@ def get_output_quantization_params(
149149
Raises:
150150
RuntimeError if no output quantization parameters are found.
151151
"""
152-
quant_params = {}
153-
for node in output_node.args[0]:
154-
if node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default:
155-
quant_params[node] = QuantizationParams(
156-
node_name=node.args[0].name,
157-
scale=node.args[1],
158-
zp=node.args[2],
159-
qmin=node.args[3],
160-
qmax=node.args[4],
161-
dtype=node.args[5],
152+
quant_params: dict[Node, QuantizationParams | None] = {}
153+
for node in output_node.args[0]: # type: ignore[union-attr]
154+
if (
155+
node.target # type: ignore[union-attr]
156+
== torch.ops.quantized_decomposed.dequantize_per_tensor.default
157+
):
158+
quant_params[node] = QuantizationParams( # type: ignore[index]
159+
node_name=node.args[0].name, # type: ignore[arg-type, union-attr]
160+
scale=node.args[1], # type: ignore[arg-type, union-attr]
161+
zp=node.args[2], # type: ignore[arg-type, union-attr]
162+
qmin=node.args[3], # type: ignore[arg-type, union-attr]
163+
qmax=node.args[4], # type: ignore[arg-type, union-attr]
164+
dtype=node.args[5], # type: ignore[arg-type, union-attr]
162165
)
163166
else:
164-
quant_params[node] = None
167+
quant_params[node] = None # type: ignore[index]
165168
return quant_params
166169

167170

168171
def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
169172
dtype = _torch_to_numpy_dtype_dict[tensor.dtype]
170-
array = tensor.detach().numpy().astype(dtype)
173+
array = tensor.detach().numpy().astype(dtype) # type: ignore[var-annotated]
171174
dim_order = tensor.dim_order()
172175
if dim_order == NHWC_ORDER:
173176
a = array.transpose(NHWC_ORDER)
@@ -252,40 +255,39 @@ def run_target(
252255
executorch_program_manager: ExecutorchProgramManager,
253256
inputs: Tuple[torch.Tensor],
254257
intermediate_path: str | Path,
255-
target_board: Literal["corestone-300", "corestone-320", "vkml_emulation_layer"],
258+
target_board: str,
256259
elf_path: str | Path,
257260
timeout: int = 120, # s
258261
):
259262
if target_board not in VALID_TARGET:
260263
raise ValueError(f"Unsupported target: {target_board}")
261264

262-
if target_board in ("corstone-300", "corstone-320"):
263-
return run_corstone(
264-
executorch_program_manager,
265-
inputs,
266-
intermediate_path,
267-
target_board,
268-
elf_path,
269-
timeout,
270-
)
271-
elif target_board == "vkml_emulation_layer":
265+
if target_board == "vkml_emulation_layer":
272266
return run_vkml_emulation_layer(
273267
executorch_program_manager,
274268
inputs,
275269
intermediate_path,
276270
elf_path,
277271
)
272+
return run_corstone(
273+
executorch_program_manager,
274+
inputs,
275+
intermediate_path,
276+
target_board,
277+
elf_path,
278+
timeout,
279+
)
278280

279281

280282
def save_inputs_to_file(
281283
exported_program: ExportedProgram,
282284
inputs: Tuple[torch.Tensor],
283285
intermediate_path: str | Path,
284286
):
285-
input_file_paths = []
287+
input_file_paths: list[str] = []
286288
input_names = get_input_names(exported_program)
287289
for input_name, input_ in zip(input_names, inputs):
288-
input_path = save_bytes(intermediate_path, input_, input_name)
290+
input_path = save_bytes(intermediate_path, input_, input_name) # type: ignore[arg-type]
289291
input_file_paths.append(input_path)
290292

291293
return input_file_paths
@@ -298,9 +300,9 @@ def get_output_from_file(
298300
):
299301
output_np = []
300302
output_node = exported_program.graph_module.graph.output_node()
301-
for i, node in enumerate(output_node.args[0]):
303+
for i, node in enumerate(output_node.args[0]): # type: ignore[union-attr]
302304
output_dtype = node.meta["val"].dtype
303-
tosa_ref_output = np.fromfile(
305+
tosa_ref_output = np.fromfile( # type: ignore[var-annotated]
304306
os.path.join(intermediate_path, f"{output_base_name}-{i}.bin"),
305307
_torch_to_numpy_dtype_dict[output_dtype],
306308
)
@@ -362,7 +364,7 @@ def run_corstone(
362364
executorch_program_manager: ExecutorchProgramManager,
363365
inputs: Tuple[torch.Tensor],
364366
intermediate_path: str | Path,
365-
target_board: Literal["corestone-300", "corestone-320"],
367+
target_board: str,
366368
elf_path: str | Path,
367369
timeout: int = 120, # s
368370
) -> list[torch.Tensor]:
@@ -759,7 +761,7 @@ def run_tosa_graph(
759761
inputs_np = [torch_tensor_to_numpy(input_tensor) for input_tensor in inputs]
760762

761763
if isinstance(tosa_version, Tosa_1_00):
762-
import tosa_reference_model as reference_model
764+
import tosa_reference_model as reference_model # type: ignore[import-untyped]
763765

764766
debug_mode = "ALL" if logger.level <= logging.DEBUG else None
765767
outputs_np, status = reference_model.run(
@@ -781,7 +783,7 @@ def run_tosa_graph(
781783
# Convert output numpy arrays to tensors with same dim_order as the output nodes
782784
result = [
783785
numpy_to_torch_tensor(output_array, node)
784-
for output_array, node in zip(outputs_np, output_node.args[0])
786+
for output_array, node in zip(outputs_np, output_node.args[0]) # type: ignore[arg-type]
785787
]
786788

787789
return result

backends/arm/test/test_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import subprocess
99
import sys
1010
import time
11+
from typing import Sequence
1112

1213

1314
def get_args():
@@ -96,7 +97,7 @@ def get_args():
9697
return args
9798

9899

99-
def run_external_cmd(cmd: []):
100+
def run_external_cmd(cmd: Sequence[str]) -> None:
100101
print("CALL:", *cmd, sep=" ")
101102
try:
102103
subprocess.check_call(cmd)

0 commit comments

Comments
 (0)