Skip to content

Commit aba691a

Browse files
authored
Merge branch 'main' into toupstream/linear_int16
2 parents 51e5262 + 69f79b9 commit aba691a

File tree

73 files changed

+176
-199
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+176
-199
lines changed

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import itertools
98
import operator
@@ -52,7 +51,7 @@ def _match_partition_to_node(
5251
raise RuntimeError(f"Cannot find an input node which matches, {node}.")
5352

5453
def call(self, graph_module: GraphModule) -> PassResult:
55-
matmul_partitions = get_source_partitions(
54+
matmul_partitions_map = get_source_partitions(
5655
graph_module.graph,
5756
[
5857
torch.matmul,
@@ -61,7 +60,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
6160
None,
6261
)
6362
matmul_partitions = list(
64-
itertools.chain.from_iterable(matmul_partitions.values())
63+
itertools.chain.from_iterable(matmul_partitions_map.values())
6564
)
6665
matmul_targets = {
6766
exir_ops.edge.aten.bmm.default,
@@ -89,7 +88,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
8988
# Create new dq-node before matmul
9089
dq_node = create_node(
9190
graph=graph_module.graph,
92-
op_target=cast(EdgeOpOverload, input_node.target), # type: ignore[arg-type]
91+
op_target=cast(EdgeOpOverload, input_node.target),
9392
)
9493
dq_node.args = (node, *input_node.args[1:])
9594
matmul_node.replace_input_with(node, dq_node)
@@ -110,7 +109,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
110109
# Create q-node after matmul
111110
q_node = create_node(
112111
graph=graph_module.graph,
113-
op_target=cast(EdgeOpOverload, partition_output.target), # type: ignore[arg-type]
112+
op_target=cast(EdgeOpOverload, partition_output.target),
114113
)
115114
matmul_node.replace_all_uses_with(q_node)
116115
q_node.args = (matmul_node, *partition_output.args[1:])

backends/arm/_passes/arm_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import traceback
98
from abc import abstractmethod

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
# pyre-unsafe
9-
108

119
from collections import defaultdict
1210

@@ -194,7 +192,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
194192
self.add_pass(ConvertExpandCopyToRepeatPass())
195193
self.add_pass(UnsqueezeBeforeRepeatPass())
196194
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
197-
self.add_pass(DecomposeSumPass())
198195
self.add_pass(DecomposeCumsumPass(exported_program))
199196
self.add_pass(Conv1dUnsqueezePass())
200197
self.add_pass(DecomposeMaxPool2DPass())
@@ -215,10 +212,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
215212
self.add_pass(RewriteMatmulPass())
216213
self.add_pass(RewriteUpsamplePass())
217214
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
215+
self.add_pass(InsertRescaleInt32Pass())
216+
self.add_pass(DecomposeSumPass())
218217
self.add_pass(ToTosaMemoryFormatPass(exported_program))
219218
self.add_pass(RemoveNoopPass())
220219
self.add_pass(InsertRescalePass())
221-
self.add_pass(InsertRescaleInt32Pass())
222220

223221
self.validate_constraints_mandatory()
224222
return self._transform(exported_program.graph_module)
@@ -361,7 +359,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
361359

362360
self.add_pass(ConvertMinMaxPass())
363361
self.add_pass(ReplaceInfValues())
364-
self.add_pass(DecomposeSumPass())
365362

366363
if not self.tosa_spec.is_U55_subset:
367364
# Uses where which is not supported on Ethos-U55

backends/arm/_passes/arm_pass_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
# pyre-unsafe
98

109
import traceback
1110
from inspect import isclass
@@ -14,8 +13,10 @@
1413
import torch
1514
import torch.fx
1615
from executorch.backends.arm.common.debug import get_node_debug_info
16+
from executorch.backends.arm.common.type import ensure_type
1717
from executorch.exir import ExportedProgram
1818
from executorch.exir.dialects._ops import ops as exir_ops
19+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1920

2021
from torch._export.utils import (
2122
get_buffer,
@@ -82,17 +83,18 @@ def get_param_tensor(
8283
elif is_lifted_tensor_constant(exp_prog, node):
8384
return get_lifted_tensor_constant(exp_prog, node)
8485
elif is_get_attr_node(node):
86+
target_node = ensure_type(str, node.target)
8587
# This is a hack to support both lifted and unlifted graph
8688
try:
87-
return getattr(node.graph.owning_module, node.target) # type: ignore[arg-type]
89+
return getattr(node.graph.owning_module, target_node)
8890
except AttributeError:
89-
return getattr(exp_prog.graph_module, node.target) # type: ignore[arg-type]
91+
return getattr(exp_prog.graph_module, target_node)
9092
raise RuntimeError(f"unsupported param type, {node.op}.")
9193

9294

9395
def create_node(
9496
graph: torch.fx.Graph,
95-
op_target: OpOverload,
97+
op_target: OpOverload | EdgeOpOverload,
9698
args: tuple = (),
9799
kwargs: Optional[dict] = None,
98100
quantize: bool = False,

backends/arm/_passes/cast_int64_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import logging
98
from typing import Set, Type

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import logging
98
from typing import cast, Set, Type

backends/arm/_passes/convert_int64_const_ops_to_int32.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
7-
86

97
import logging
108
from typing import Set, Type

backends/arm/_passes/convert_int64_output_ops_to_int32.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
7-
86

97
import logging
108
from typing import Set, Type

backends/arm/_passes/convert_int_pow_to_mul.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
from typing import Set, Type
98

backends/arm/_passes/convert_split_to_slice.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
from typing import Set, Type
98

0 commit comments

Comments
 (0)