
Commit a71332d

Revert "Arm backend: Fix arg-type MyPy errors (pytorch#15367)"
This reverts commit c66078c.
1 parent 16f7f7a commit a71332d

File tree

11 files changed: +37, -81 lines changed


backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 4 additions & 4 deletions
@@ -51,7 +51,7 @@ def _match_partition_to_node(
         raise RuntimeError(f"Cannot find an input node which matches, {node}.")

     def call(self, graph_module: GraphModule) -> PassResult:
-        matmul_partitions_map = get_source_partitions(
+        matmul_partitions = get_source_partitions(
             graph_module.graph,
             [
                 torch.matmul,
@@ -60,7 +60,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
             None,
         )
         matmul_partitions = list(
-            itertools.chain.from_iterable(matmul_partitions_map.values())
+            itertools.chain.from_iterable(matmul_partitions.values())
         )
         matmul_targets = {
             exir_ops.edge.aten.bmm.default,
@@ -88,7 +88,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 # Create new dq-node before matmul
                 dq_node = create_node(
                     graph=graph_module.graph,
-                    op_target=cast(EdgeOpOverload, input_node.target),
+                    op_target=cast(EdgeOpOverload, input_node.target),  # type: ignore[arg-type]
                 )
                 dq_node.args = (node, *input_node.args[1:])
                 matmul_node.replace_input_with(node, dq_node)
@@ -109,7 +109,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 # Create q-node after matmul
                 q_node = create_node(
                     graph=graph_module.graph,
-                    op_target=cast(EdgeOpOverload, partition_output.target),
+                    op_target=cast(EdgeOpOverload, partition_output.target),  # type: ignore[arg-type]
                 )
                 matmul_node.replace_all_uses_with(q_node)
                 q_node.args = (matmul_node, *partition_output.args[1:])

backends/arm/_passes/arm_pass_utils.py

Lines changed: 3 additions & 6 deletions
@@ -13,10 +13,8 @@
 import torch
 import torch.fx
 from executorch.backends.arm.common.debug import get_node_debug_info
-from executorch.backends.arm.common.type import ensure_type
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.dialects.edge._ops import EdgeOpOverload

 from torch._export.utils import (
     get_buffer,
@@ -83,18 +81,17 @@ def get_param_tensor(
     elif is_lifted_tensor_constant(exp_prog, node):
         return get_lifted_tensor_constant(exp_prog, node)
     elif is_get_attr_node(node):
-        target_node = ensure_type(str, node.target)
         # This is a hack to support both lifted and unlifted graph
         try:
-            return getattr(node.graph.owning_module, target_node)
+            return getattr(node.graph.owning_module, node.target)  # type: ignore[arg-type]
         except AttributeError:
-            return getattr(exp_prog.graph_module, target_node)
+            return getattr(exp_prog.graph_module, node.target)  # type: ignore[arg-type]
     raise RuntimeError(f"unsupported param type, {node.op}.")


 def create_node(
     graph: torch.fx.Graph,
-    op_target: OpOverload | EdgeOpOverload,
+    op_target: OpOverload,
     args: tuple = (),
     kwargs: Optional[dict] = None,
     quantize: bool = False,
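Note on the create_node hunk: widening op_target to a union was what let call sites pass an EdgeOpOverload without suppression; with the annotation reverted to the narrower OpOverload, each such call site gets a per-line ignore again (see annotate_decomposed_matmul.py above). A generic sketch of that trade-off, using stand-in classes rather than the real operator types:

    class OpOverload: ...        # stand-in for torch._ops.OpOverload
    class EdgeOpOverload: ...    # stand-in for the edge-dialect overload type


    def create_node_union(op_target: OpOverload | EdgeOpOverload) -> str:
        # Widened annotation: both call sites below type-check as written.
        return type(op_target).__name__


    def create_node_narrow(op_target: OpOverload) -> str:
        # Narrow annotation: passing an EdgeOpOverload needs a suppression.
        return type(op_target).__name__


    edge_op = EdgeOpOverload()
    print(create_node_union(edge_op))
    print(create_node_narrow(edge_op))  # type: ignore[arg-type]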

backends/arm/_passes/scalars_to_attribute_pass.py

Lines changed: 2 additions & 2 deletions
@@ -49,15 +49,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 shape = get_first_fake_tensor(arg).shape
                 biggest_rank = max(biggest_rank, len(shape))

-            new_args: list[Node | int] = []
+            new_args = []
             for arg in n.args:
                 if isinstance(arg, Node):
                     new_args.append(arg)
                     continue
                 if isinstance(arg, int) and not torch.is_floating_point(
                     get_first_fake_tensor(n)
                 ):
-                    new_args.append(arg)
+                    new_args.append(arg)  # type: ignore[arg-type]
                     continue

             prefix = "_tensor_constant_"

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 1 addition & 7 deletions
@@ -259,19 +259,13 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):

         # Transpose outputs if they are in (N)NCHW format
         outputs = output_node.args[0]
-        if not isinstance(outputs, (list, tuple)):
-            raise TypeError(
-                f"Expected output node args to be a list or tuple, got {type(outputs)}"
-            )
         output_dim_orders = output_node.meta.get("original_dim_orders")
         if output_dim_orders is None:
             raise RuntimeError(
                 f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}."
             )

-        for output_node_input, output_dim_order in zip(
-            outputs, output_dim_orders, strict=True
-        ):
+        for output_node_input, output_dim_order in zip(outputs, output_dim_orders):  # type: ignore[arg-type]
             if output_dim_order in (
                 NCHW_ORDER,
                 NNCHW_ORDER,
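Two things disappear in this hunk: the isinstance() guard that narrowed outputs for mypy, and strict=True on zip(), which (in Python 3.10+) raises if the two sequences differ in length instead of silently truncating. A minimal, generic illustration of both, not the pass itself:

    def pair_up(outputs: object, dim_orders: list[tuple[int, ...]]) -> None:
        if not isinstance(outputs, (list, tuple)):
            raise TypeError(f"Expected a list or tuple, got {type(outputs)}")
        # After the isinstance() check mypy knows `outputs` is iterable, so the
        # zip() call needs no type: ignore; strict=True turns a length mismatch
        # into a ValueError rather than dropping the extra elements.
        for out, order in zip(outputs, dim_orders, strict=True):
            print(out, order)


    pair_up(["a", "b"], [(0, 2, 3, 1), (0, 1, 2, 3)])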

backends/arm/common/type.py

Lines changed: 0 additions & 28 deletions
This file was deleted.
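The removed module's contents are not shown in this diff. Judging only from its call sites elsewhere in the commit (ensure_type(str, node.target), ensure_type(torch.fx.Node, index), and so on), a helper of this shape usually amounts to an isinstance check that narrows the value for both mypy and the runtime; the sketch below is a hypothetical reconstruction, not the deleted code:

    from typing import TypeVar

    T = TypeVar("T")


    def ensure_type(expected_type: type[T], value: object) -> T:
        # Hypothetical reconstruction: narrow `value` to `expected_type`,
        # failing loudly instead of hiding the mismatch behind a type: ignore.
        if not isinstance(value, expected_type):
            raise TypeError(
                f"Expected {expected_type.__name__}, got {type(value).__name__}"
            )
        return value


    print(ensure_type(str, "weight"))   # "weight"
    # ensure_type(str, 42) would raise TypeError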

backends/arm/operator_support/index_tensor_support.py

Lines changed: 2 additions & 5 deletions
@@ -14,7 +14,6 @@
 import torch
 import torch.fx as fx
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
-from executorch.backends.arm.common.type import ensure_type
 from executorch.backends.arm.operator_support.tosa_supported_operators import (
     register_tosa_support_check,
     SupportedTOSAOperatorCheck,
@@ -138,8 +137,7 @@ def is_node_tosa_supported(
             return False

         # Usage 1 guard
-        index = ensure_type(torch.fx.Node, index)
-        fake_tensor = get_first_fake_tensor(index)
+        fake_tensor = get_first_fake_tensor(index)  # type: ignore[arg-type]
         if len(fake_tensor.size()) > 3:
             self.reporter.report_reject(
                 node,
@@ -148,8 +146,7 @@ def is_node_tosa_supported(
             return False

         # Usage 3 guard
-        input_node = ensure_type(torch.fx.Node, node.args[0])
-        total_vals = math.prod(get_first_fake_tensor(input_node).shape)
+        total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape)  # type: ignore[arg-type]
         if total_vals > torch.iinfo(torch.int32).max:
             self.reporter.report_reject(
                 node,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 5 additions & 3 deletions
@@ -219,7 +219,7 @@ def _is_matmul_node_supported(
     """
     for graph_module in submodules.values():
         graph_module = typing.cast(fx.GraphModule, graph_module)
-        matmul_partitions_map = get_source_partitions(
+        matmul_partitions = get_source_partitions(
             graph_module.graph,
             [
                 torch.matmul,
@@ -228,7 +228,7 @@ def _is_matmul_node_supported(
             None,
         )
         matmul_partitions = list(
-            itertools.chain.from_iterable(matmul_partitions_map.values())
+            itertools.chain.from_iterable(matmul_partitions.values())
        )
         matched_partition = None
         for partition in matmul_partitions:
@@ -406,7 +406,9 @@ def is_node_supported(
             if input_node.target in ComputeConstantOpsAOT.targeted_ops:
                 # This is not perfect since the input_node can still be rejected by other checks but
                 # this should cover the majority of cases.
-                if self.is_node_supported({}, input_node):
+                if self.is_node_supported(
+                    None, input_node  # type: ignore[arg-type]  # (we don't use 'submodules')
+                ):
                     continue
                 self.reporter.report_reject(
                     node, f"Non-constant int64 input {input_node.name}"

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -374,7 +374,7 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         # TODO: Fix the need to lazily import this.
         from executorch.backends.arm._passes import ArmPassManager

-        return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(
+        return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(  # type: ignore[arg-type]
             graph_module=model
         )


backends/arm/quantizer/quantization_annotator.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import torch.fx
1313
import torch.nn.functional as F
1414
from executorch.backends.arm.common.debug import get_node_debug_info
15-
from executorch.backends.arm.common.type import ensure_type
1615
from executorch.backends.arm.quantizer import QuantizationConfig
1716
from torch._subclasses import FakeTensor
1817

@@ -511,8 +510,7 @@ def any_or_hardtanh_min_zero(n: Node):
511510
torch.ops.aten.minimum.default,
512511
torch.ops.aten.maximum.default,
513512
):
514-
lhs_node = ensure_type(Node, node.args[0])
515-
shared_qspec = SharedQuantizationSpec((lhs_node, node))
513+
shared_qspec = SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type]
516514
quant_properties.quant_inputs = [
517515
_QuantProperty(0, input_act_qspec),
518516
_QuantProperty(
@@ -522,24 +520,22 @@ def any_or_hardtanh_min_zero(n: Node):
522520
]
523521
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
524522
elif node.target in (torch.ops.aten.where.self,):
525-
true_node = ensure_type(Node, node.args[1])
526-
shared_qspec = SharedQuantizationSpec(true_node)
523+
shared_qspec = SharedQuantizationSpec(node.args[1]) # type: ignore[arg-type]
527524
quant_properties.quant_inputs = [
528525
_QuantProperty(1, shared_qspec),
529526
_QuantProperty(2, shared_qspec),
530527
]
531528
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
532529
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
533-
input_node = ensure_type(Node, node.args[0])
534530
input_qspec = (
535-
SharedQuantizationSpec(input_node)
536-
if is_output_annotated(input_node)
531+
SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
532+
if is_output_annotated(node.args[0]) # type: ignore[arg-type]
537533
else input_act_qspec
538534
)
539535
quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)]
540536
quant_properties.quant_output = _QuantProperty(
541537
0,
542-
SharedQuantizationSpec((input_node, node)),
538+
SharedQuantizationSpec((node.args[0], node)), # type: ignore[arg-type]
543539
)
544540
elif node.target in (
545541
torch.ops.aten.cat.default,
@@ -554,24 +550,26 @@ def any_or_hardtanh_min_zero(n: Node):
554550
)
555551
if len(node.args[0]) == 0:
556552
raise ValueError("Expected non-empty list for node.args[0]")
557-
inputs = [ensure_type(Node, element) for element in node.args[0]]
558-
shared_qspec = SharedQuantizationSpec((inputs[0], node))
553+
554+
shared_qspec = SharedQuantizationSpec((node.args[0][0], node)) # type: ignore[arg-type]
559555
quant_properties.quant_inputs = [
560556
_QuantProperty(
561557
0,
562-
[input_act_qspec if n == inputs[0] else shared_qspec for n in inputs],
558+
[
559+
input_act_qspec if n == node.args[0][0] else shared_qspec # type: ignore[misc]
560+
for n in node.args[0]
561+
],
563562
)
564563
]
565564
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
566565
elif node.target in _one_to_one:
567566
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
568567
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
569568
elif node.target in _one_to_one_shared_input_qspec:
570-
input_node = ensure_type(Node, node.args[0])
571569
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
572570
quant_properties.quant_output = _QuantProperty(
573571
0,
574-
SharedQuantizationSpec((input_node, node)),
572+
SharedQuantizationSpec((node.args[0], node)), # type: ignore[arg-type]
575573
)
576574
elif node.target in [
577575
torch.ops.aten.eq.Tensor,
@@ -580,8 +578,7 @@ def any_or_hardtanh_min_zero(n: Node):
580578
torch.ops.aten.le.Tensor,
581579
torch.ops.aten.lt.Tensor,
582580
]:
583-
input_node = ensure_type(Node, node.args[0])
584-
shared_qspec = SharedQuantizationSpec((input_node, node))
581+
shared_qspec = SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type]
585582
quant_properties.quant_inputs = [
586583
_QuantProperty(0, input_act_qspec),
587584
_QuantProperty(
@@ -599,10 +596,9 @@ def any_or_hardtanh_min_zero(n: Node):
599596
quant_properties.quant_inputs = []
600597
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
601598
elif node.target in [operator.getitem]:
602-
input_node = ensure_type(Node, node.args[0])
603-
if not is_output_annotated(input_node):
599+
if not is_output_annotated(node.args[0]): # type: ignore[arg-type]
604600
return None
605-
shared_qspec = SharedQuantizationSpec(input_node)
601+
shared_qspec = SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
606602
quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]
607603
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
608604
else:
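A recurring pattern in this file: elements of Node.args are typed as a broad Argument union in torch.fx, so handing one to an API that expects a Node (such as SharedQuantizationSpec here) needs either a runtime check that narrows it, which is what the reverted helper provided, or the per-line ignores restored by this commit. A small self-contained sketch of the narrowing approach:

    import torch
    from torch.fx import Node, symbolic_trace


    def first_input_node(node: Node) -> Node:
        # node.args[0] is statically a union of many argument kinds, not Node.
        arg = node.args[0]
        if not isinstance(arg, Node):
            raise TypeError(f"Expected Node, got {type(arg).__name__}")
        return arg


    gm = symbolic_trace(torch.nn.ReLU())
    relu = next(n for n in gm.graph.nodes if n.op == "call_function")
    print(first_input_node(relu))  # the placeholder feeding relu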

backends/arm/test/tester/arm_tester.py

Lines changed: 3 additions & 3 deletions
@@ -604,9 +604,9 @@ def run_transform_for_annotation_pipeline(
         # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
         artifact = self.get_artifact(stage)
         if self.cur == StageType.EXPORT:
-            new_gm = ArmPassManager(
-                self.compile_spec.tosa_spec
-            ).transform_for_annotation_pipeline(graph_module=artifact.graph_module)
+            new_gm = ArmPassManager(self.compile_spec.tosa_spec).transform_for_annotation_pipeline(  # type: ignore[arg-type]
+                graph_module=artifact.graph_module
+            )
         else:
             raise RuntimeError("Can only run passes on Export stage.")
         _copy_module(artifact.graph_module, new_gm)
