8 changes: 4 additions & 4 deletions backends/arm/_passes/annotate_decomposed_matmul.py
@@ -52,7 +52,7 @@ def _match_partition_to_node(
raise RuntimeError(f"Cannot find an input node which matches, {node}.")

def call(self, graph_module: GraphModule) -> PassResult:
matmul_partitions = get_source_partitions(
matmul_partitions_map = get_source_partitions(
graph_module.graph,
[
torch.matmul,
@@ -61,7 +61,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
None,
)
matmul_partitions = list(
itertools.chain.from_iterable(matmul_partitions.values())
itertools.chain.from_iterable(matmul_partitions_map.values())
)
matmul_targets = {
exir_ops.edge.aten.bmm.default,
@@ -89,7 +89,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
# Create new dq-node before matmul
dq_node = create_node(
graph=graph_module.graph,
op_target=cast(EdgeOpOverload, input_node.target), # type: ignore[arg-type]
op_target=cast(EdgeOpOverload, input_node.target),
)
dq_node.args = (node, *input_node.args[1:])
matmul_node.replace_input_with(node, dq_node)
@@ -110,7 +110,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
# Create q-node after matmul
q_node = create_node(
graph=graph_module.graph,
op_target=cast(EdgeOpOverload, partition_output.target), # type: ignore[arg-type]
op_target=cast(EdgeOpOverload, partition_output.target),
)
matmul_node.replace_all_uses_with(q_node)
q_node.args = (matmul_node, *partition_output.args[1:])
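Note: the rename to matmul_partitions_map keeps the dict returned by get_source_partitions and the flattened list in separate variables, so mypy can infer a stable type for each instead of one name switching from dict to list. A minimal sketch of the pattern with toy values (names here are illustrative, not taken from the pass):

import itertools

grouped = {"matmul": [1, 2], "bmm": [3]}           # shape of what get_source_partitions returns: a dict of lists
flattened = list(itertools.chain.from_iterable(grouped.values()))
print(flattened)                                   # [1, 2, 3]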
9 changes: 6 additions & 3 deletions backends/arm/_passes/arm_pass_utils.py
@@ -14,8 +14,10 @@
import torch
import torch.fx
from executorch.backends.arm.common.debug import get_node_debug_info
from executorch.backends.arm.common.type import ensure_type
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload

from torch._export.utils import (
get_buffer,
@@ -82,17 +84,18 @@ def get_param_tensor(
elif is_lifted_tensor_constant(exp_prog, node):
return get_lifted_tensor_constant(exp_prog, node)
elif is_get_attr_node(node):
target_node = ensure_type(str, node.target)
# This is a hack to support both lifted and unlifted graph
try:
return getattr(node.graph.owning_module, node.target) # type: ignore[arg-type]
return getattr(node.graph.owning_module, target_node)
except AttributeError:
return getattr(exp_prog.graph_module, node.target) # type: ignore[arg-type]
return getattr(exp_prog.graph_module, target_node)
raise RuntimeError(f"unsupported param type, {node.op}.")


def create_node(
graph: torch.fx.Graph,
op_target: OpOverload,
op_target: OpOverload | EdgeOpOverload,
args: tuple = (),
kwargs: Optional[dict] = None,
quantize: bool = False,
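Widening op_target to OpOverload | EdgeOpOverload reflects how create_node is called from the passes: both core ATen targets and edge dialect targets are passed in, and with the union in the signature the cast(EdgeOpOverload, ...) call sites above no longer need a # type: ignore[arg-type]. A quick illustration of the two target kinds (assuming an ExecuTorch install; not part of the patch):

import torch
from executorch.exir.dialects._ops import ops as exir_ops

core_target = torch.ops.aten.bmm.default       # torch._ops.OpOverload
edge_target = exir_ops.edge.aten.bmm.default   # EdgeOpOverload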
4 changes: 2 additions & 2 deletions backends/arm/_passes/scalars_to_attribute_pass.py
@@ -50,15 +50,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
shape = get_first_fake_tensor(arg).shape
biggest_rank = max(biggest_rank, len(shape))

new_args = []
new_args: list[Node | int] = []
for arg in n.args:
if isinstance(arg, Node):
new_args.append(arg)
continue
if isinstance(arg, int) and not torch.is_floating_point(
get_first_fake_tensor(n)
):
new_args.append(arg) # type: ignore[arg-type]
new_args.append(arg)
continue

prefix = "_tensor_constant_"
8 changes: 7 additions & 1 deletion backends/arm/_passes/to_tosa_memory_format_pass.py
@@ -261,13 +261,19 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):

# Transpose outputs if they are in (N)NCHW format
outputs = output_node.args[0]
if not isinstance(outputs, (list, tuple)):
raise TypeError(
f"Expected output node args to be a list or tuple, got {type(outputs)}"
)
output_dim_orders = output_node.meta.get("original_dim_orders")
if output_dim_orders is None:
raise RuntimeError(
f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}."
)

for output_node_input, output_dim_order in zip(outputs, output_dim_orders): # type: ignore[arg-type]
for output_node_input, output_dim_order in zip(
outputs, output_dim_orders, strict=True
):
if output_dim_order in (
NCHW_ORDER,
NNCHW_ORDER,
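The strict=True flag (Python 3.10+) turns a length mismatch between the graph outputs and the recorded original_dim_orders into an immediate error instead of silently truncating the shorter sequence, for example:

outputs = ["out0", "out1"]
dim_orders = [(0, 1, 2, 3)]  # one entry missing

try:
    list(zip(outputs, dim_orders, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1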
28 changes: 28 additions & 0 deletions backends/arm/common/type.py
@@ -0,0 +1,28 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Type checking utilities."""

from typing import TypeVar

T = TypeVar("T")


def ensure_type(expected_type: type[T], arg: object) -> T:
"""Ensure that the argument is of the expected type.

Args:
expected_type (type[T]): The expected type.
arg (object): The argument to check.

Returns:
T: The argument, if it is of the expected type.

"""
if isinstance(arg, expected_type):
return arg

expected_name = getattr(expected_type, "__name__", str(expected_type))
actual_name = type(arg).__name__
raise TypeError(f"Expected value of type {expected_name}, got {actual_name!r}")
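ensure_type narrows a value typed as object (or a broad union, such as a node argument) to a concrete type for the type checker, and fails loudly at runtime if that assumption is wrong; the call sites in this PR use it in place of # type: ignore[arg-type] comments. A small usage sketch with toy values:

from executorch.backends.arm.common.type import ensure_type

value: object = "weight_0"
name = ensure_type(str, value)   # mypy now treats name as str
print(name.upper())              # WEIGHT_0

ensure_type(int, value)          # raises TypeError: Expected value of type int, got 'str'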
7 changes: 5 additions & 2 deletions backends/arm/operator_support/index_tensor_support.py
@@ -14,6 +14,7 @@
import torch
import torch.fx as fx
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.common.type import ensure_type
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
@@ -137,7 +138,8 @@ def is_node_tosa_supported(
return False

# Usage 1 guard
fake_tensor = get_first_fake_tensor(index) # type: ignore[arg-type]
index = ensure_type(torch.fx.Node, index)
fake_tensor = get_first_fake_tensor(index)
if len(fake_tensor.size()) > 3:
self.reporter.report_reject(
node,
@@ -146,7 +148,8 @@ def is_node_tosa_supported(
return False

# Usage 3 guard
total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape) # type: ignore[arg-type]
input_node = ensure_type(torch.fx.Node, node.args[0])
total_vals = math.prod(get_first_fake_tensor(input_node).shape)
if total_vals > torch.iinfo(torch.int32).max:
self.reporter.report_reject(
node,
8 changes: 3 additions & 5 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -219,7 +219,7 @@ def _is_matmul_node_supported(
"""
for graph_module in submodules.values():
graph_module = typing.cast(fx.GraphModule, graph_module)
matmul_partitions = get_source_partitions(
matmul_partitions_map = get_source_partitions(
graph_module.graph,
[
torch.matmul,
@@ -228,7 +228,7 @@ def _is_matmul_node_supported(
None,
)
matmul_partitions = list(
itertools.chain.from_iterable(matmul_partitions.values())
itertools.chain.from_iterable(matmul_partitions_map.values())
)
matched_partition = None
for partition in matmul_partitions:
@@ -406,9 +406,7 @@ def is_node_supported(
if input_node.target in ComputeConstantOpsAOT.targeted_ops:
# This is not perfect since the input_node can still be rejected by other checks but
# this should cover the majority of cases.
if self.is_node_supported(
None, input_node # type: ignore[arg-type] #(we don't use 'submodules')
):
if self.is_node_supported({}, input_node):
continue
self.reporter.report_reject(
node, f"Non-constant int64 input {input_node.name}"
2 changes: 1 addition & 1 deletion backends/arm/quantizer/arm_quantizer.py
@@ -374,7 +374,7 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
# TODO: Fix the need to lazily import this.
from executorch.backends.arm._passes import ArmPassManager

return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline( # type: ignore[arg-type]
return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(
graph_module=model
)

34 changes: 19 additions & 15 deletions backends/arm/quantizer/quantization_annotator.py
@@ -12,6 +12,7 @@
import torch.fx
import torch.nn.functional as F
from executorch.backends.arm.common.debug import get_node_debug_info
from executorch.backends.arm.common.type import ensure_type
from executorch.backends.arm.quantizer import QuantizationConfig
from torch._subclasses import FakeTensor

@@ -510,7 +511,8 @@ def any_or_hardtanh_min_zero(n: Node):
torch.ops.aten.minimum.default,
torch.ops.aten.maximum.default,
):
shared_qspec = SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type]
lhs_node = ensure_type(Node, node.args[0])
shared_qspec = SharedQuantizationSpec((lhs_node, node))
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
_QuantProperty(
@@ -520,22 +522,24 @@ def any_or_hardtanh_min_zero(n: Node):
]
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
elif node.target in (torch.ops.aten.where.self,):
shared_qspec = SharedQuantizationSpec(node.args[1]) # type: ignore[arg-type]
true_node = ensure_type(Node, node.args[1])
shared_qspec = SharedQuantizationSpec(true_node)
quant_properties.quant_inputs = [
_QuantProperty(1, shared_qspec),
_QuantProperty(2, shared_qspec),
]
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
input_node = ensure_type(Node, node.args[0])
input_qspec = (
SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
if is_output_annotated(node.args[0]) # type: ignore[arg-type]
SharedQuantizationSpec(input_node)
if is_output_annotated(input_node)
else input_act_qspec
)
quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)]
quant_properties.quant_output = _QuantProperty(
0,
SharedQuantizationSpec((node.args[0], node)), # type: ignore[arg-type]
SharedQuantizationSpec((input_node, node)),
)
elif node.target in (
torch.ops.aten.cat.default,
@@ -550,26 +554,24 @@ def any_or_hardtanh_min_zero(n: Node):
)
if len(node.args[0]) == 0:
raise ValueError("Expected non-empty list for node.args[0]")

shared_qspec = SharedQuantizationSpec((node.args[0][0], node)) # type: ignore[arg-type]
inputs = [ensure_type(Node, element) for element in node.args[0]]
shared_qspec = SharedQuantizationSpec((inputs[0], node))
quant_properties.quant_inputs = [
_QuantProperty(
0,
[
input_act_qspec if n == node.args[0][0] else shared_qspec # type: ignore[misc]
for n in node.args[0]
],
[input_act_qspec if n == inputs[0] else shared_qspec for n in inputs],
)
]
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
elif node.target in _one_to_one:
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif node.target in _one_to_one_shared_input_qspec:
input_node = ensure_type(Node, node.args[0])
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
quant_properties.quant_output = _QuantProperty(
0,
SharedQuantizationSpec((node.args[0], node)), # type: ignore[arg-type]
SharedQuantizationSpec((input_node, node)),
)
elif node.target in [
torch.ops.aten.eq.Tensor,
@@ -578,7 +580,8 @@ def any_or_hardtanh_min_zero(n: Node):
torch.ops.aten.le.Tensor,
torch.ops.aten.lt.Tensor,
]:
shared_qspec = SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type]
input_node = ensure_type(Node, node.args[0])
shared_qspec = SharedQuantizationSpec((input_node, node))
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
_QuantProperty(
@@ -596,9 +599,10 @@ def any_or_hardtanh_min_zero(n: Node):
quant_properties.quant_inputs = []
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif node.target in [operator.getitem]:
if not is_output_annotated(node.args[0]): # type: ignore[arg-type]
input_node = ensure_type(Node, node.args[0])
if not is_output_annotated(input_node):
return None
shared_qspec = SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
shared_qspec = SharedQuantizationSpec(input_node)
quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
else:
6 changes: 3 additions & 3 deletions backends/arm/test/tester/arm_tester.py
@@ -604,9 +604,9 @@ def run_transform_for_annotation_pipeline(
# We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
artifact = self.get_artifact(stage)
if self.cur == StageType.EXPORT:
new_gm = ArmPassManager(self.compile_spec.tosa_spec).transform_for_annotation_pipeline( # type: ignore[arg-type]
graph_module=artifact.graph_module
)
new_gm = ArmPassManager(
self.compile_spec.tosa_spec
).transform_for_annotation_pipeline(graph_module=artifact.graph_module)
else:
raise RuntimeError("Can only run passes on Export stage.")
_copy_module(artifact.graph_module, new_gm)
4 changes: 3 additions & 1 deletion backends/arm/tosa/partitioner.py
@@ -22,6 +22,7 @@
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
calculate_multiples,
)
from executorch.backends.arm.common.type import ensure_type
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.backends.arm.operator_support.tosa_supported_operators import (
tosa_support_factory,
@@ -86,7 +87,8 @@ def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool:
if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
return False
else:
return node.meta.get("dtype") == get_first_fake_tensor(node.args[0]).dtype # type: ignore[arg-type]
input_node = ensure_type(torch.fx.Node, node.args[0])
return node.meta.get("dtype") == get_first_fake_tensor(input_node).dtype


def is_noop_expand(node: torch.fx.node.Node) -> bool: