Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ exclude_patterns = [
'**/third-party/**',
'scripts/check_binary_dependencies.py',
'profiler/test/test_profiler_e2e.py',
'backends/arm/test/**',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

]
command = [
'python',
Expand Down
16 changes: 11 additions & 5 deletions backends/arm/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@
from typing import Any

import pytest
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.test.runner_utils import (
from executorch.backends.arm.arm_backend import ( # type: ignore[import-not-found]
ArmCompileSpecBuilder,
)
from executorch.backends.arm.test.runner_utils import ( # type: ignore[import-not-found]
arm_executor_runner_exists,
corstone300_installed,
corstone320_installed,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.backends.arm.tosa_specification import ( # type: ignore[import-not-found]
TosaSpecification,
)
from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found]
CompileSpec,
)


def get_time_formatted_path(path: str, log_prefix: str) -> str:
Expand Down Expand Up @@ -193,7 +199,7 @@ def get_u85_compile_spec_unbuilt(


def parametrize(
arg_name: str, test_data: dict[str, Any], xfails: dict[str, str] = None
arg_name: str, test_data: dict[str, Any], xfails: dict[str, str] | None = None
):
"""
Custom version of pytest.mark.parametrize with some syntatic sugar and added xfail functionality
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

try:
import tosa_reference_model
import tosa_reference_model # type: ignore[import-untyped]
except ImportError:
logging.warning("tosa_reference_model not found, can't run reference model tests")
tosa_reference_model = None
Expand Down
52 changes: 33 additions & 19 deletions backends/arm/test/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,36 @@

logger = logging.getLogger(__name__)
try:
import tosa_reference_model
import tosa_reference_model # type: ignore[import-untyped]
except ImportError:
tosa_reference_model = None
from executorch.backends.arm.arm_backend import get_tosa_spec, is_tosa

from executorch.backends.arm.test.conftest import is_option_enabled
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir import ExecutorchProgramManager, ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.backends.arm.arm_backend import ( # type: ignore[import-not-found]
get_tosa_spec,
is_tosa,
)

from executorch.backends.arm.test.conftest import ( # type: ignore[import-not-found]
is_option_enabled,
)
from executorch.backends.arm.tosa_specification import ( # type: ignore[import-not-found]
TosaSpecification,
)
from executorch.exir import ( # type: ignore[import-not-found]
ExecutorchProgramManager,
ExportedProgram,
)
from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found]
CompileSpec,
)
from executorch.exir.lowered_backend_module import ( # type: ignore[import-not-found]
LoweredBackendModule,
)
from packaging.version import Version
from torch.fx.node import Node

from torch.overrides import TorchFunctionMode
from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict
from tosa import TosaGraph
from tosa import TosaGraph # type: ignore[import-untyped]

logger = logging.getLogger(__name__)
logger.setLevel(logging.CRITICAL)
Expand Down Expand Up @@ -148,16 +162,16 @@ def get_output_quantization_params(
for node in output_nodes:
if node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default:
quant_params[node] = QuantizationParams(
node_name=node.args[0].name,
scale=node.args[1],
zp=node.args[2],
qmin=node.args[3],
qmax=node.args[4],
dtype=node.args[5],
node_name=node.args[0].name, # type: ignore
scale=node.args[1], # type: ignore
zp=node.args[2], # type: ignore
qmin=node.args[3], # type: ignore
qmax=node.args[4], # type: ignore
dtype=node.args[5], # type: ignore
)
else:
quant_params[node] = None
return quant_params
quant_params[node] = None # type: ignore[assignment]
return quant_params # type: ignore[return-value]


class TosaReferenceModelDispatch(TorchFunctionMode):
Expand Down Expand Up @@ -243,7 +257,7 @@ def run_corstone(
input_names = get_input_names(exported_program)
input_paths = []
for input_name, input_ in zip(input_names, inputs):
input_path = save_bytes(intermediate_path, input_, input_name)
input_path = save_bytes(intermediate_path.as_posix(), input_, input_name)
input_paths.append(input_path)

out_path = os.path.join(intermediate_path, "out")
Expand Down Expand Up @@ -348,7 +362,7 @@ def prep_data_for_save(
quant_param: Optional[QuantizationParams] = None,
):
if isinstance(data, torch.Tensor):
data_np = np.array(data.detach(), order="C").astype(
data_np = np.array(data.detach(), order="C").astype( # type: ignore[var-annotated]
torch_to_numpy_dtype_dict[data.dtype]
)
else:
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_args():
return args


def run_external_cmd(cmd: []):
def run_external_cmd(cmd: list[str]):
print("CALL:", *cmd, sep=" ")
try:
subprocess.check_call(cmd)
Expand Down
Loading