
Commit cdd6c91

Fix pyre issues
Address issues from pyre and add similar # pyre-ignores as in #7362.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I6feaa611dcd539b3b0d21a6a7dd696ef7db691ef

Parent: b8343a2

10 files changed, +51 -28 lines


backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 7 additions & 9 deletions

@@ -5,18 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
-from typing import Any, Dict, List
 
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule
-from torch.fx.passes.utils.source_matcher_utils import (
-    get_source_partitions,
-    SourcePartition,
-)
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 
 
 class AnnotateDecomposedMatmulPass(ExportPass):
@@ -28,8 +24,8 @@ class AnnotateDecomposedMatmulPass(ExportPass):
     matmul-op (can be mm or bmm).
     """
 
-    def call(self, graph_module: GraphModule):
-        matmul_partitions: Dict[Any, List[SourcePartition]] = get_source_partitions(
+    def call(self, graph_module: GraphModule) -> PassResult:
+        matmul_partitions = get_source_partitions(
             graph_module.graph,
             [
                 torch.matmul,
@@ -56,7 +52,7 @@ def call(self, graph_module: GraphModule):
                 input_node = partition.input_nodes[i]
                 matmul_input_node = matmul_args[i]
                 # Remove partition input dq-node
-                input_node.replace_all_uses_with(input_node.args[0])
+                input_node.replace_all_uses_with(input_node.all_input_nodes[0])
                 graph_module.graph.erase_node(input_node)
                 input_node_qargs = input_node.args[1:]
                 with graph_module.graph.inserting_before(matmul_node):
@@ -81,7 +77,9 @@
                 matmul_node.replace_all_uses_with(q_node)
                 q_node.args = (matmul_node, *output_node_qargs)
                 # Remove partition output q-node
-                partition_output.replace_all_uses_with(partition_output.args[0])
+                partition_output.replace_all_uses_with(
+                    partition_output.all_input_nodes[0]
+                )
                 graph_module.graph.erase_node(partition_output)
 
         # retrace the graph to update the fake tensor types
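
Note on the `args[0]` to `all_input_nodes[0]` change: `torch.fx.Node.args` is typed `Tuple[Argument, ...]`, where an `Argument` may be a `Node` but also a plain constant, so pyre cannot prove that `args[0]` is a `Node`; `all_input_nodes` is `List[Node]`, which is what `replace_all_uses_with` expects. A minimal, self-contained sketch of the distinction (the traced function is illustrative, not part of this commit):

```python
import operator

import torch
from torch.fx import Node, symbolic_trace


def f(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


gm = symbolic_trace(f)
add = next(n for n in gm.graph.nodes if n.target is operator.add)

# Node.args may mix Nodes with constants: here it is (relu_node, 1),
# so a type checker cannot treat args[0] as a Node.
first_arg = add.args[0]

# all_input_nodes contains only the Node-valued inputs, typed List[Node].
first_input: Node = add.all_input_nodes[0]
assert first_arg is first_input  # both refer to the relu node here
```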

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 17 additions & 5 deletions

@@ -6,14 +6,20 @@
 
 import copy
 
-from typing import cast, Iterable
+from typing import cast, Dict, Iterable, Set, Tuple
 
 from executorch.backends.arm.tosa_quant_utils import QuantArgs
 
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 
-from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.pass_base import (
+    Argument,
+    ExportPass,
+    NodeMetadata,
+    PassResult,
+    ProxyValue,
+)
 from torch.fx import GraphModule, Node
 
 q_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
@@ -82,7 +88,7 @@ def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None:
 
     def fold_and_annotate_arg(
         self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
-    ):
+    ) -> None:
         input_qparams = None
         nodes_to_remove = set()
         for arg in arg_list:
@@ -210,11 +216,17 @@ class RetraceFoldedDtypesPass(ExportPass):
     the output type of that matches the type of the output_qparams.
     """
 
-    targeted_ops = {
+    targeted_ops: Set[EdgeOpOverload] = {
         exir_ops.edge.aten.sum.dim_IntList,
     }
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
         if op not in self.targeted_ops:
             return super().call_operator(op, args, kwargs, meta)
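
For reference, the typed `call_operator` override pattern in isolation; the pass name below is made up, but the imports and signature are the ones this commit introduces:

```python
from typing import Dict, Tuple

from executorch.exir.pass_base import (
    Argument,
    ExportPass,
    NodeMetadata,
    ProxyValue,
)


class NoopPass(ExportPass):
    """Hypothetical pass showing the fully typed call_operator signature."""

    def call_operator(
        self,
        op,  # pyre-ignore: left unannotated, as in the commit
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        # A real pass would rewrite targeted ops before delegating.
        return super().call_operator(op, args, kwargs, meta)
```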

backends/arm/_passes/insert_table_ops.py

Lines changed: 5 additions & 4 deletions

@@ -4,14 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable
+from typing import Callable, Dict
 
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.backends.arm.tosa_quant_utils import QuantArgs
 from executorch.exir import ExportedProgram
 
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
 
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule
@@ -22,7 +23,7 @@
 
 
 @impl(lib, "_table")
-def _table_impl(*args, **kwargs):
+def _table_impl(*args, **kwargs):  # pyre-ignore
     return args[0]
 
 
@@ -34,7 +35,7 @@ class InsertTableOpsPass(ExportPass):
     which will be used to produce the table values in operators/op_table.py.
    """
 
-    table_ops = {
+    table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
         exir_ops.edge.aten.exp.default: torch.exp,
         exir_ops.edge.aten.log.default: torch.log,
         exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
@@ -43,7 +44,7 @@ class InsertTableOpsPass(ExportPass):
         exir_ops.edge.aten.tanh.default: torch.tanh,
     }
 
-    def __init__(self, exported_program: ExportedProgram):
+    def __init__(self, exported_program: ExportedProgram) -> None:
         super().__init__()
         self.exported_program = exported_program
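
The annotated `table_ops` maps each edge op to the elementwise `torch` function used to fill its lookup table. Roughly, a table can be built by dequantizing every INT8 value, applying the mapped function, and requantizing; the scale/zero-point values below are invented, and this is only a sketch of the mechanism, not the actual op_table.py code:

```python
import torch

# Invented quantization parameters, for illustration only.
in_scale, in_zp = 0.1, 0
out_scale, out_zp = 0.05, -128

qs = torch.arange(-128, 128, dtype=torch.int32)  # every INT8 input value
x = (qs - in_zp) * in_scale                      # dequantize
y = torch.exp(x)                                 # a table_ops entry, e.g. exp
table = torch.clamp(
    torch.round(y / out_scale) + out_zp, -128, 127
).to(torch.int8)
assert table.numel() == 256                      # one entry per INT8 value
```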

backends/arm/operators/op_bmm.py

Lines changed: 5 additions & 3 deletions

@@ -9,6 +9,8 @@
 
 import serializer.tosa_serializer as ts
 import torch
+
+# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_input_qparams,
     get_output_qparams,
@@ -49,7 +51,7 @@ def define_node(
         # for a later rescale.
 
         if inputs[0].dtype == ts.DType.INT8:
-            input_qparams = get_input_qparams(node)
+            input_qparams = get_input_qparams(node)  # pyre-ignore[16]
             input0_zp = input_qparams[0].zp
             input1_zp = input_qparams[1].zp
             bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
@@ -71,9 +73,9 @@
 
         # As INT8 accumulates into INT32, we need to rescale it back to INT8
         if output.dtype == ts.DType.INT8:
-            output_qparams = get_output_qparams(node)[0]
+            output_qparams = get_output_qparams(node)[0]  # pyre-ignore[16]
             final_output_scale = (
-                input_qparams[0].scale * input_qparams[1].scale
+                input_qparams[0].scale * input_qparams[1].scale  # pyre-ignore[61]
             ) / output_qparams.scale
 
             build_rescale(
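
The two suppression styles above differ in intent: `# pyre-fixme[N]` marks a known problem to be fixed later (here code 21, an import pyre cannot resolve), while `# pyre-ignore[N]` silences error code N on its line indefinitely (here 16, undefined attribute, and 61, possibly-undefined local). A toy illustration of both comment forms, using a made-up `Config` class:

```python
class Config:
    """Toy class with no declared attributes."""


cfg = Config()

# pyre-fixme[16]: `Config` has no attribute `threshold`; flagged to fix later.
cfg.threshold = 0.5

print(cfg.threshold)  # pyre-ignore[16]: intentionally dynamic attribute
```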

backends/arm/operators/op_hardtanh.py

Lines changed: 3 additions & 1 deletion

@@ -8,6 +8,8 @@
 
 import serializer.tosa_serializer as ts
 import torch
+
+# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_input_qparams,
 )
@@ -39,7 +41,7 @@ def define_node(
 
         if inputs[0].dtype == ts.DType.INT8:
             # Get quant parameters
-            input_qparams = get_input_qparams(node)
+            input_qparams = get_input_qparams(node)  # pyre-ignore[16]
             qargs = input_qparams[0]
             # Convert to quantized representation
             clamp_min_qs = quantize_value(inputs[1].number, qargs)

backends/arm/operators/op_max_pool2d.py

Lines changed: 4 additions & 2 deletions

@@ -8,6 +8,8 @@
 
 import serializer.tosa_serializer as ts
 import torch
+
+# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_input_qparams,
     get_output_qparams,
@@ -49,12 +51,12 @@ def define_node(
         # Initilize zero point to zero.
         input_zp = 0
         if inputs[0].dtype == ts.DType.INT8:
-            input_qparams = get_input_qparams(node)
+            input_qparams = get_input_qparams(node)  # pyre-ignore[16]
             input_zp = input_qparams[0].zp
 
         output_zp = 0
         if output.dtype == ts.DType.INT8:
-            output_qparams = get_output_qparams(node)
+            output_qparams = get_output_qparams(node)  # pyre-ignore[16]
             output_zp = output_qparams[0].zp
 
         attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_mm.py

Lines changed: 4 additions & 2 deletions

@@ -9,6 +9,8 @@
 
 import serializer.tosa_serializer as ts
 import torch
+
+# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_input_qparams,
     get_output_qparams,
@@ -52,7 +54,7 @@ def define_node(
         # The output also needs to be rank 3
         output_new_shape = (1, output.shape[0], output.shape[1])
 
-        input_qparams = get_input_qparams(node)
+        input_qparams = get_input_qparams(node)  # pyre-ignore[16]
         assert len(input_qparams) == 2
         input0_qparams = input_qparams[0]
         input1_qparams = input_qparams[1]
@@ -78,7 +80,7 @@
         )
 
         # As INT8 accumulates into INT32, we need to rescale it back to INT8
-        output_qparams = get_output_qparams(node)
+        output_qparams = get_output_qparams(node)  # pyre-ignore[16]
         assert len(output_qparams) == 1
 
         final_output_scale = (
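
The `final_output_scale` expression above encodes standard integer-matmul requantization: each INT32 accumulator unit is worth `input0_scale * input1_scale` in real terms, so converting the result into INT8 output units divides by the output scale. A worked example with made-up scales:

```python
# Made-up quantization scales, chosen to be exact in binary floating point.
scale_in0, scale_in1, scale_out = 0.25, 0.5, 0.5

# One INT32 accumulator unit represents scale_in0 * scale_in1 real units;
# dividing by scale_out converts the result into INT8 output units.
final_output_scale = (scale_in0 * scale_in1) / scale_out
print(final_output_scale)  # 0.25
```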

backends/arm/operators/op_mul.py

Lines changed: 3 additions & 1 deletion

@@ -12,6 +12,8 @@
 
 import serializer.tosa_serializer as ts
 import torch
+
+# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_input_qparams,
 )
@@ -43,7 +45,7 @@ def define_node(
         assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8
         input_A = inputs[0]
         input_B = inputs[1]
-        input_qparams = get_input_qparams(node)
+        input_qparams = get_input_qparams(node)  # pyre-ignore[16]
         input_A_qargs = input_qparams[0]
         input_B_qargs = input_qparams[1]
         input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)

backends/arm/operators/op_relu.py

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,8 @@
 import executorch.backends.arm.tosa_quant_utils as tqutils
 import serializer.tosa_serializer as ts
 import torch.fx
+
+# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_output_qparams,
 )

backends/arm/tosa_quant_utils.py

Lines changed: 1 addition & 1 deletion

@@ -145,7 +145,7 @@ def quantize_value(self, x):
             self.qmax,
         ).to(self.dtype)
 
-    def dequantize_value(self, qx: int) -> float:
+    def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:
         return (qx - self.zp) * self.scale
 
     def __eq__(self, other):
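
The new signature matches how `dequantize_value` is used in practice: `(qx - self.zp) * self.scale` broadcasts over tensors rather than single ints. A round-trip sketch of the affine scheme `QuantArgs` implements, with invented parameters:

```python
import torch

# Invented quantization parameters (scale, zero point, INT8 range).
scale, zp, qmin, qmax = 0.05, 8, -128, 127

x = torch.tensor([-1.0, 0.0, 0.5])

# quantize_value: snap to the grid, shift by zp, clamp to the INT8 range.
qx = torch.clip(torch.round(x / scale) + zp, qmin, qmax).to(torch.int8)

# dequantize_value, now Tensor -> Tensor: (qx - zp) * scale.
x_hat = (qx - zp) * scale

print(qx)     # tensor([-12, 8, 18], dtype=torch.int8)
print(x_hat)  # tensor([-1.0000, 0.0000, 0.5000])
```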
