From f88d250a6750911a9b7e04b3a773d0b3b70246b0 Mon Sep 17 00:00:00 2001
From: Sebastian Larsson
Date: Mon, 6 Oct 2025 09:09:17 +0200
Subject: [PATCH] Arm backend: Fix arg-type MyPy errors

Introduce function `ensure_type()`, which raises a TypeError when an
argument is not of the expected type. This resolves a number of mypy
arg-type errors spread around the codebase. Additionally, fix the
remaining arg-type mypy errors directly.

Change-Id: I2677944ac70b44b913d1817c4225707b2bc39e14
Signed-off-by: Sebastian Larsson
---
 .../arm/_passes/annotate_decomposed_matmul.py |  8 ++---
 backends/arm/_passes/arm_pass_utils.py        |  9 +++--
 .../arm/_passes/scalars_to_attribute_pass.py  |  4 +--
 .../arm/_passes/to_tosa_memory_format_pass.py |  8 ++++-
 backends/arm/common/type.py                   | 28 +++++++++++++++
 .../operator_support/index_tensor_support.py  |  7 ++--
 .../tosa_supported_operators.py               |  8 ++---
 backends/arm/quantizer/arm_quantizer.py       |  2 +-
 .../arm/quantizer/quantization_annotator.py   | 34 +++++++++++--------
 backends/arm/test/tester/arm_tester.py        |  6 ++--
 backends/arm/tosa/partitioner.py              |  4 ++-
 11 files changed, 81 insertions(+), 37 deletions(-)
 create mode 100644 backends/arm/common/type.py

diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py
index 666214ec267..e7f02d14cd1 100644
--- a/backends/arm/_passes/annotate_decomposed_matmul.py
+++ b/backends/arm/_passes/annotate_decomposed_matmul.py
@@ -52,7 +52,7 @@ def _match_partition_to_node(
         raise RuntimeError(f"Cannot find an input node which matches, {node}.")

     def call(self, graph_module: GraphModule) -> PassResult:
-        matmul_partitions = get_source_partitions(
+        matmul_partitions_map = get_source_partitions(
             graph_module.graph,
             [
                 torch.matmul,
@@ -61,7 +61,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
             None,
         )
         matmul_partitions = list(
-            itertools.chain.from_iterable(matmul_partitions.values())
+            itertools.chain.from_iterable(matmul_partitions_map.values())
         )
         matmul_targets = {
             exir_ops.edge.aten.bmm.default,
@@ -89,7 +89,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                         # Create new dq-node before matmul
                         dq_node = create_node(
                             graph=graph_module.graph,
-                            op_target=cast(EdgeOpOverload, input_node.target),  # type: ignore[arg-type]
+                            op_target=cast(EdgeOpOverload, input_node.target),
                         )
                         dq_node.args = (node, *input_node.args[1:])
                         matmul_node.replace_input_with(node, dq_node)
@@ -110,7 +110,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                     # Create q-node after matmul
                     q_node = create_node(
                         graph=graph_module.graph,
-                        op_target=cast(EdgeOpOverload, partition_output.target),  # type: ignore[arg-type]
+                        op_target=cast(EdgeOpOverload, partition_output.target),
                     )
                     matmul_node.replace_all_uses_with(q_node)
                     q_node.args = (matmul_node, *partition_output.args[1:])
diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index 71e2030958f..777b4603547 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -14,8 +14,10 @@
 import torch
 import torch.fx
 from executorch.backends.arm.common.debug import get_node_debug_info
+from executorch.backends.arm.common.type import ensure_type
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload

 from torch._export.utils import (
     get_buffer,
@@ -82,17 +84,18 @@ def get_param_tensor(
     elif is_lifted_tensor_constant(exp_prog, node):
         return get_lifted_tensor_constant(exp_prog, node)
     elif is_get_attr_node(node):
+        target_node = ensure_type(str, node.target)
         # This is a hack to support both lifted and unlifted graph
         try:
-            return getattr(node.graph.owning_module, node.target)  # type: ignore[arg-type]
+            return getattr(node.graph.owning_module, target_node)
         except AttributeError:
-            return getattr(exp_prog.graph_module, node.target)  # type: ignore[arg-type]
+            return getattr(exp_prog.graph_module, target_node)
     raise RuntimeError(f"unsupported param type, {node.op}.")


 def create_node(
     graph: torch.fx.Graph,
-    op_target: OpOverload,
+    op_target: OpOverload | EdgeOpOverload,
     args: tuple = (),
     kwargs: Optional[dict] = None,
     quantize: bool = False,
diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py
index 5ca6a60e844..c7188fef077 100644
--- a/backends/arm/_passes/scalars_to_attribute_pass.py
+++ b/backends/arm/_passes/scalars_to_attribute_pass.py
@@ -50,7 +50,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                     shape = get_first_fake_tensor(arg).shape
                     biggest_rank = max(biggest_rank, len(shape))

-            new_args = []
+            new_args: list[Node | int] = []
             for arg in n.args:
                 if isinstance(arg, Node):
                     new_args.append(arg)
@@ -58,7 +58,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 if isinstance(arg, int) and not torch.is_floating_point(
                     get_first_fake_tensor(n)
                 ):
-                    new_args.append(arg)  # type: ignore[arg-type]
+                    new_args.append(arg)
                     continue

                 prefix = "_tensor_constant_"
diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py
index 3783f782610..1e0edd60763 100644
--- a/backends/arm/_passes/to_tosa_memory_format_pass.py
+++ b/backends/arm/_passes/to_tosa_memory_format_pass.py
@@ -261,13 +261,19 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):

         # Transpose outputs if they are in (N)NCHW format
         outputs = output_node.args[0]
+        if not isinstance(outputs, (list, tuple)):
+            raise TypeError(
+                f"Expected output node args to be a list or tuple, got {type(outputs)}"
+            )
         output_dim_orders = output_node.meta.get("original_dim_orders")
         if output_dim_orders is None:
             raise RuntimeError(
                 f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass "
                 f"pipeline when using {ToTosaMemoryFormatPass.__name__}."
             )
-        for output_node_input, output_dim_order in zip(outputs, output_dim_orders):  # type: ignore[arg-type]
+        for output_node_input, output_dim_order in zip(
+            outputs, output_dim_orders, strict=True
+        ):
             if output_dim_order in (
                 NCHW_ORDER,
                 NNCHW_ORDER,
diff --git a/backends/arm/common/type.py b/backends/arm/common/type.py
new file mode 100644
index 00000000000..e53dc1ee769
--- /dev/null
+++ b/backends/arm/common/type.py
@@ -0,0 +1,28 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""Type checking utilities."""
+
+from typing import TypeVar
+
+T = TypeVar("T")
+
+
+def ensure_type(expected_type: type[T], arg: object) -> T:
+    """Ensure that the argument is of the expected type.
+
+    Args:
+        expected_type (type[T]): The expected type.
+        arg (object): The argument to check.
+
+    Returns:
+        T: The argument, if it is of the expected type.
+
+    """
+    if isinstance(arg, expected_type):
+        return arg
+
+    expected_name = getattr(expected_type, "__name__", str(expected_type))
+    actual_name = type(arg).__name__
+    raise TypeError(f"Expected value of type {expected_name}, got {actual_name!r}")
diff --git a/backends/arm/operator_support/index_tensor_support.py b/backends/arm/operator_support/index_tensor_support.py
index 92b0ce48a32..5de70c0a2de 100644
--- a/backends/arm/operator_support/index_tensor_support.py
+++ b/backends/arm/operator_support/index_tensor_support.py
@@ -14,6 +14,7 @@
 import torch
 import torch.fx as fx
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm.common.type import ensure_type
 from executorch.backends.arm.operator_support.tosa_supported_operators import (
     register_tosa_support_check,
     SupportedTOSAOperatorCheck,
@@ -137,7 +138,8 @@ def is_node_tosa_supported(
             return False

         # Usage 1 guard
-        fake_tensor = get_first_fake_tensor(index)  # type: ignore[arg-type]
+        index = ensure_type(torch.fx.Node, index)
+        fake_tensor = get_first_fake_tensor(index)
         if len(fake_tensor.size()) > 3:
             self.reporter.report_reject(
                 node,
@@ -146,7 +148,8 @@ def is_node_tosa_supported(
             return False

         # Usage 3 guard
-        total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape)  # type: ignore[arg-type]
+        input_node = ensure_type(torch.fx.Node, node.args[0])
+        total_vals = math.prod(get_first_fake_tensor(input_node).shape)
         if total_vals > torch.iinfo(torch.int32).max:
             self.reporter.report_reject(
                 node,
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index f7857894d40..ba479818a81 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -219,7 +219,7 @@ def _is_matmul_node_supported(
         """
         for graph_module in submodules.values():
             graph_module = typing.cast(fx.GraphModule, graph_module)
-            matmul_partitions = get_source_partitions(
+            matmul_partitions_map = get_source_partitions(
                 graph_module.graph,
                 [
                     torch.matmul,
@@ -228,7 +228,7 @@ def _is_matmul_node_supported(
                 None,
             )
             matmul_partitions = list(
-                itertools.chain.from_iterable(matmul_partitions.values())
+                itertools.chain.from_iterable(matmul_partitions_map.values())
             )
             matched_partition = None
             for partition in matmul_partitions:
@@ -406,9 +406,7 @@ def is_node_supported(
                 if input_node.target in ComputeConstantOpsAOT.targeted_ops:
                     # This is not perfect since the input_node can still be rejected by other checks but
                     # this should cover the majority of cases.
-                    if self.is_node_supported(
-                        None, input_node  # type: ignore[arg-type] #(we don't use 'submodules')
-                    ):
+                    if self.is_node_supported({}, input_node):
                         continue
                 self.reporter.report_reject(
                     node, f"Non-constant int64 input {input_node.name}"
diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py
index 2b0b028c5e4..e6b1358e7e0 100644
--- a/backends/arm/quantizer/arm_quantizer.py
+++ b/backends/arm/quantizer/arm_quantizer.py
@@ -374,7 +374,7 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         # TODO: Fix the need to lazily import this.
         from executorch.backends.arm._passes import ArmPassManager

-        return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(  # type: ignore[arg-type]
+        return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(
             graph_module=model
         )

diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index b429bacd738..ee7003aacb8 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -12,6 +12,7 @@
 import torch.fx
 import torch.nn.functional as F
 from executorch.backends.arm.common.debug import get_node_debug_info
+from executorch.backends.arm.common.type import ensure_type
 from executorch.backends.arm.quantizer import QuantizationConfig
 from torch._subclasses import FakeTensor

@@ -510,7 +511,8 @@ def any_or_hardtanh_min_zero(n: Node):
         torch.ops.aten.minimum.default,
         torch.ops.aten.maximum.default,
     ):
-        shared_qspec = SharedQuantizationSpec((node.args[0], node))  # type: ignore[arg-type]
+        lhs_node = ensure_type(Node, node.args[0])
+        shared_qspec = SharedQuantizationSpec((lhs_node, node))
         quant_properties.quant_inputs = [
             _QuantProperty(0, input_act_qspec),
             _QuantProperty(
@@ -520,22 +522,24 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)
     elif node.target in (torch.ops.aten.where.self,):
-        shared_qspec = SharedQuantizationSpec(node.args[1])  # type: ignore[arg-type]
+        true_node = ensure_type(Node, node.args[1])
+        shared_qspec = SharedQuantizationSpec(true_node)
         quant_properties.quant_inputs = [
             _QuantProperty(1, shared_qspec),
             _QuantProperty(2, shared_qspec),
         ]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)
     elif node.target in _one_to_one_shared_input_or_input_act_qspec:
+        input_node = ensure_type(Node, node.args[0])
         input_qspec = (
-            SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
-            if is_output_annotated(node.args[0])  # type: ignore[arg-type]
+            SharedQuantizationSpec(input_node)
+            if is_output_annotated(input_node)
             else input_act_qspec
         )
         quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)]
         quant_properties.quant_output = _QuantProperty(
             0,
-            SharedQuantizationSpec((node.args[0], node)),  # type: ignore[arg-type]
+            SharedQuantizationSpec((input_node, node)),
         )
     elif node.target in (
         torch.ops.aten.cat.default,
@@ -550,15 +554,12 @@ def any_or_hardtanh_min_zero(n: Node):
         )
         if len(node.args[0]) == 0:
             raise ValueError("Expected non-empty list for node.args[0]")
-
-        shared_qspec = SharedQuantizationSpec((node.args[0][0], node))  # type: ignore[arg-type]
+        inputs = [ensure_type(Node, element) for element in node.args[0]]
+        shared_qspec = SharedQuantizationSpec((inputs[0], node))
         quant_properties.quant_inputs = [
             _QuantProperty(
                 0,
-                [
-                    input_act_qspec if n == node.args[0][0] else shared_qspec  # type: ignore[misc]
-                    for n in node.args[0]
-                ],
+                [input_act_qspec if n == inputs[0] else shared_qspec for n in inputs],
             )
         ]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)
@@ -566,10 +567,11 @@ def any_or_hardtanh_min_zero(n: Node):
         quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
     elif node.target in _one_to_one_shared_input_qspec:
+        input_node = ensure_type(Node, node.args[0])
         quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
         quant_properties.quant_output = _QuantProperty(
             0,
-            SharedQuantizationSpec((node.args[0], node)),  # type: ignore[arg-type]
+            SharedQuantizationSpec((input_node, node)),
         )
     elif node.target in [
         torch.ops.aten.eq.Tensor,
@@ -578,7 +580,8 @@ def any_or_hardtanh_min_zero(n: Node):
         torch.ops.aten.le.Tensor,
         torch.ops.aten.lt.Tensor,
     ]:
-        shared_qspec = SharedQuantizationSpec((node.args[0], node))  # type: ignore[arg-type]
+        input_node = ensure_type(Node, node.args[0])
+        shared_qspec = SharedQuantizationSpec((input_node, node))
         quant_properties.quant_inputs = [
             _QuantProperty(0, input_act_qspec),
             _QuantProperty(
@@ -596,9 +599,10 @@ def any_or_hardtanh_min_zero(n: Node):
         quant_properties.quant_inputs = []
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
     elif node.target in [operator.getitem]:
-        if not is_output_annotated(node.args[0]):  # type: ignore[arg-type]
+        input_node = ensure_type(Node, node.args[0])
+        if not is_output_annotated(input_node):
             return None
-        shared_qspec = SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
+        shared_qspec = SharedQuantizationSpec(input_node)
         quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)
     else:
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index 44b1a7aef13..7be249609b0 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -604,9 +604,9 @@ def run_transform_for_annotation_pipeline(
         # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
         artifact = self.get_artifact(stage)
         if self.cur == StageType.EXPORT:
-            new_gm = ArmPassManager(self.compile_spec.tosa_spec).transform_for_annotation_pipeline(  # type: ignore[arg-type]
-                graph_module=artifact.graph_module
-            )
+            new_gm = ArmPassManager(
+                self.compile_spec.tosa_spec
+            ).transform_for_annotation_pipeline(graph_module=artifact.graph_module)
         else:
             raise RuntimeError("Can only run passes on Export stage.")
         _copy_module(artifact.graph_module, new_gm)
diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py
index 6eb1dcbef72..c00e003861c 100644
--- a/backends/arm/tosa/partitioner.py
+++ b/backends/arm/tosa/partitioner.py
@@ -22,6 +22,7 @@
 from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
     calculate_multiples,
 )
+from executorch.backends.arm.common.type import ensure_type
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
 from executorch.backends.arm.operator_support.tosa_supported_operators import (
     tosa_support_factory,
@@ -86,7 +87,8 @@ def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool:
     if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
         return False
     else:
-        return node.meta.get("dtype") == get_first_fake_tensor(node.args[0]).dtype  # type: ignore[arg-type]
+        input_node = ensure_type(torch.fx.Node, node.args[0])
+        return node.meta.get("dtype") == get_first_fake_tensor(input_node).dtype


 def is_noop_expand(node: torch.fx.node.Node) -> bool:
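
Usage note (illustrative only, not part of the patch): the sketch below shows the calling pattern the new `ensure_type()` helper is meant to support. The helper body mirrors backends/arm/common/type.py as added above; the `target` variable and its example values are hypothetical stand-ins for loosely typed values such as `node.target` or `node.args[0]` at the call sites touched by this change.

    from typing import TypeVar

    T = TypeVar("T")

    def ensure_type(expected_type: type[T], arg: object) -> T:
        """Return `arg` narrowed to `expected_type`, or raise TypeError."""
        if isinstance(arg, expected_type):
            return arg
        expected_name = getattr(expected_type, "__name__", str(expected_type))
        actual_name = type(arg).__name__
        raise TypeError(f"Expected value of type {expected_name}, got {actual_name!r}")

    # A value whose static type is too loose for mypy (object), similar to
    # `node.target` in the FX graph code above.
    target: object = "my_attribute_name"  # hypothetical example value

    name = ensure_type(str, target)  # mypy now treats `name` as a str
    print(name.upper())              # safe to use str methods

    try:
        ensure_type(int, target)     # wrong type: raises at runtime
    except TypeError as err:
        print(err)                   # Expected value of type int, got 'str'

Because `ensure_type()` both narrows the static type for mypy and validates at runtime, call sites can drop the scattered `# type: ignore[arg-type]` comments without silencing genuine type errors.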