|
13 | 13 | from executorch.backends.arm._passes.arm_pass_utils import ( |
14 | 14 | get_param_tensor, |
15 | 15 | is_param_node, |
16 | | - set_node_arg, |
17 | 16 | ) |
18 | 17 | from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass |
19 | 18 |
|
|
23 | 22 | from executorch.exir import ExportedProgram |
24 | 23 |
|
25 | 24 | from executorch.exir.dialects._ops import ops as exir_ops |
| 25 | +from executorch.exir.dialects.edge._ops import EdgeOpOverload |
26 | 26 |
|
27 | 27 | from executorch.exir.pass_base import ExportPass, PassResult |
28 | 28 | from torch.fx import GraphModule, Node |
@@ -66,6 +66,38 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]: |
66 | 66 | return output_qparams |
67 | 67 |
|
68 | 68 |
|
| 69 | +class RetraceFoldedDtypesPass(ArmPass): |
| 70 | + """ |
| 71 | + FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced |
| 72 | + some operators are retraced to types that cannot be handled by TOSA. One |
| 73 | + such example is sum.dim_IntList: |
| 74 | + q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ... |
| 75 | + After folding it becomes: |
| 76 | + q (int8) -> sum (int64) -> ... |
| 77 | + This pass changes types of ops in self.targeted_ops, such as sum, so that |
| 78 | + the output type of that matches the type of the output_qparams. |
| 79 | + """ |
| 80 | + |
| 81 | + _passes_required_after: Set[Type[ExportPass]] = set() |
| 82 | + |
| 83 | + targeted_ops: Set[EdgeOpOverload] = { |
| 84 | + exir_ops.edge.aten.sum.dim_IntList, |
| 85 | + } |
| 86 | + |
| 87 | + def call_operator(self, op, args, kwargs, meta): |
| 88 | + if op not in self.targeted_ops: |
| 89 | + return super().call_operator(op, args, kwargs, meta, False) |
| 90 | + |
| 91 | + node_kwargs = kwargs.copy() |
| 92 | + output_qparams = meta["output_qparams"] |
| 93 | + if len(output_qparams) == 0: |
| 94 | + return super().call_operator(op, args, kwargs, meta, False) |
| 95 | + |
| 96 | + output_dtype = output_qparams[0].dtype |
| 97 | + node_kwargs["dtype"] = output_dtype |
| 98 | + return super().call_operator(op, args, node_kwargs, meta, True) |
| 99 | + |
| 100 | + |
69 | 101 | class FoldAndAnnotateQParamsPass(ArmPass): |
70 | 102 | """ |
71 | 103 | A pass that walks the graph and removes any DQ and Q nodes before and after the target |
@@ -97,6 +129,7 @@ class FoldAndAnnotateQParamsPass(ArmPass): |
97 | 129 | """ |
98 | 130 |
|
99 | 131 | _passes_required_after: Set[Type[ExportPass]] = { |
| 132 | + RetraceFoldedDtypesPass, |
100 | 133 | InsertTableOpsPass, |
101 | 134 | RemoveNoopPass, |
102 | 135 | } |
@@ -201,16 +234,6 @@ def call(self, graph_module: GraphModule) -> PassResult: |
201 | 234 | user.replace_all_uses_with(n) |
202 | 235 | graph_module.graph.erase_node(user) |
203 | 236 |
|
204 | | - # Some op(s) contain a "dtype" key in their node kwargs. Set this |
205 | | - # to the type of output qparams. |
206 | | - output_qparams = n.meta["output_qparams"] |
207 | | - if ( |
208 | | - n.target in {exir_ops.edge.aten.sum.dim_IntList} |
209 | | - and len(output_qparams) > 0 |
210 | | - ): |
211 | | - output_dtype = output_qparams[0].dtype |
212 | | - set_node_arg(n, "dtype", output_dtype) |
213 | | - |
214 | 237 | # retrace the graph to update the fake tensor types |
215 | 238 | graph_module = super().call(graph_module).graph_module |
216 | 239 |
|
|
0 commit comments