Skip to content

Commit 37d5404

Browse files
Martin Lindström
authored and committed
Arm backend: Merge RetraceFoldedDtypesPass into FoldAndAnnotateQParamsPass
The pass RetraceFoldedDtypesPass carries out extra processing on the output of FoldAndAnnotateQParamsPass, meaning that the two passes are tightly coupled and always run in sequence. Merge these two passes together into FoldAndAnnotateQParamsPass.

Signed-off-by: Martin Lindström <[email protected]>
Change-Id: I1e68a4b87cef2778623fbac6a68a62abf5764abb
1 parent de56c81 commit 37d5404

File tree

3 files changed

+11
-38
lines changed

3 files changed

+11
-38
lines changed

backends/arm/_passes/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,6 @@
7272
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
7373
FoldAndAnnotateQParamsPass,
7474
QuantizeOperatorArguments,
75-
RetraceFoldedDtypesPass,
7675
)
7776
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
7877
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,6 @@
8888
RemoveNoopPass,
8989
ReplaceInfValues,
9090
ReplaceScalarWithTensorByProfilePass,
91-
RetraceFoldedDtypesPass,
9291
RewriteConv2dPass,
9392
RewriteMatmulPass,
9493
RewriteUpsamplePass,
@@ -176,7 +175,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
176175
self.add_pass(QuantizeOperatorArguments())
177176
self.add_pass(ConvertELUParamsPass())
178177
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
179-
self.add_pass(RetraceFoldedDtypesPass())
180178
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
181179
self.add_pass(MatchArgRanksPass(exported_program))
182180
if self.tosa_spec.is_U55_subset:
@@ -271,7 +269,6 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
271269
self.add_pass(AnnotateDecomposedMatmulPass())
272270
self.add_pass(QuantizeOperatorArguments())
273271
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
274-
self.add_pass(RetraceFoldedDtypesPass())
275272
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
276273
self.add_pass(MatchArgRanksPass(exported_program))
277274
self.add_pass(DecomposeAdaptiveAvgPool2dPass())

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 11 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.arm._passes.arm_pass_utils import (
1414
get_param_tensor,
1515
is_param_node,
16+
set_node_arg,
1617
)
1718
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
1819

@@ -22,7 +23,6 @@
2223
from executorch.exir import ExportedProgram
2324

2425
from executorch.exir.dialects._ops import ops as exir_ops
25-
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2626

2727
from executorch.exir.pass_base import ExportPass, PassResult
2828
from torch.fx import GraphModule, Node
@@ -66,38 +66,6 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
6666
return output_qparams
6767

6868

69-
class RetraceFoldedDtypesPass(ArmPass):
70-
"""
71-
FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced
72-
some operators are retraced to types that cannot be handled by TOSA. One
73-
such example is sum.dim_IntList:
74-
q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ...
75-
After folding it becomes:
76-
q (int8) -> sum (int64) -> ...
77-
This pass changes types of ops in self.targeted_ops, such as sum, so that
78-
the output type of that matches the type of the output_qparams.
79-
"""
80-
81-
_passes_required_after: Set[Type[ExportPass]] = set()
82-
83-
targeted_ops: Set[EdgeOpOverload] = {
84-
exir_ops.edge.aten.sum.dim_IntList,
85-
}
86-
87-
def call_operator(self, op, args, kwargs, meta):
88-
if op not in self.targeted_ops:
89-
return super().call_operator(op, args, kwargs, meta, False)
90-
91-
node_kwargs = kwargs.copy()
92-
output_qparams = meta["output_qparams"]
93-
if len(output_qparams) == 0:
94-
return super().call_operator(op, args, kwargs, meta, False)
95-
96-
output_dtype = output_qparams[0].dtype
97-
node_kwargs["dtype"] = output_dtype
98-
return super().call_operator(op, args, node_kwargs, meta, True)
99-
100-
10169
class FoldAndAnnotateQParamsPass(ArmPass):
10270
"""
10371
A pass that walks the graph and removes any DQ and Q nodes before and after the target
@@ -129,7 +97,6 @@ class FoldAndAnnotateQParamsPass(ArmPass):
12997
"""
13098

13199
_passes_required_after: Set[Type[ExportPass]] = {
132-
RetraceFoldedDtypesPass,
133100
InsertTableOpsPass,
134101
RemoveNoopPass,
135102
}
@@ -234,6 +201,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
234201
user.replace_all_uses_with(n)
235202
graph_module.graph.erase_node(user)
236203

204+
# Some op(s) contain a "dtype" key in their node kwargs. Set this
205+
# to the type of output qparams.
206+
output_qparams = n.meta["output_qparams"]
207+
if (
208+
n.target in {exir_ops.edge.aten.sum.dim_IntList}
209+
and len(output_qparams) > 0
210+
):
211+
output_dtype = output_qparams[0].dtype
212+
set_node_arg(n, "dtype", output_dtype)
213+
237214
# retrace the graph to update the fake tensor types
238215
graph_module = super().call(graph_module).graph_module
239216

0 commit comments

Comments
 (0)