Skip to content

Commit 38d7fc0

Browse files
Deprecate tag qdq pass in xnnbackend (#12170)
Summary: This diff decentralizes the tagging of implicit q/dq nodes: instead of tagging them in a single backend pass, each partition config (or graph pass) tags the nodes it is responsible for.
Changes in this diff:
1. Deprecate the tag q/dq pass (TagImplicitQDqPass).
2. Remove every place where this pass is used in the backend preprocess phase.
3. Decentralize the tagging to the individual configs and passes:
   a. `generic_node_configs` handles most of the non-GEMM nodes.
   b. `gemm_configs` handles the GEMM nodes.
   c. The channels-last pass inserts (copy q dq) or (dq copy q) and tags the relevant nodes.
   d. The conv1d unsqueeze pass tags the q/dq nodes it inserts.
   e. The decompose cat pass tags the q/dq nodes it inserts.
4. Deprecate configs.py, where the collections of nodes were maintained.
Fixes: #11588
Reviewed By: mcr229
Differential Revision: D77055623
1 parent 3d90515 commit 38d7fc0

16 files changed

+178
-517
lines changed

backends/xnnpack/_passes/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ python_library(
99
"//caffe2:torch",
1010
"//executorch/backends/transforms:addmm_mm_to_linear",
1111
"//executorch/backends/transforms:lib",
12-
"//executorch/backends/xnnpack/partition:configs",
1312
"//executorch/backends/xnnpack/partition:partitioner_graphs",
1413
"//executorch/backends/xnnpack/serialization:xnnpack_schema",
1514
"//executorch/backends/xnnpack/utils:xnnpack_utils",

backends/xnnpack/_passes/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@
2525
FuseBatchNormWithConvPass,
2626
)
2727
from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
28-
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
29-
TagImplicitQDqPass,
30-
)
3128
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
3229

3330
from executorch.exir.pass_base import ExportPass
@@ -70,7 +67,6 @@ def __init__(
7067
Conv1dUnsqueezePass,
7168
PReLUReshapePass,
7269
ChannelsLastTaggedReshapePass,
73-
TagImplicitQDqPass,
7470
]
7571
else:
7672
self.passes = passes

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88

99
import torch
1010
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
11-
from executorch.backends.xnnpack.utils.quant_utils import is_dynamic_qdq
11+
from executorch.backends.xnnpack.utils.quant_utils import (
12+
is_dequant,
13+
is_dynamic_qdq,
14+
is_tagged_as_implicit_q_dq,
15+
tag_as_implicit_q_dq,
16+
)
1217
from executorch.backends.xnnpack.utils.utils import is_param_node
1318
from executorch.exir.dialects._ops import ops as exir_ops
1419
from executorch.exir.pass_base import PassResult
@@ -144,17 +149,32 @@ def insert_copy_q_dq(
144149
target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
145150
args=(copy,) + q_params,
146151
)
147-
q.meta = copy.meta
152+
q.meta = copy.meta.copy()
148153

149154
with graph_module.graph.inserting_after(q):
150155
dq = self.create_call_function_node(
151156
graph_module=graph_module,
152157
target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
153158
args=(q,) + q_params,
154159
)
155-
dq.meta = q.meta
160+
dq.meta = q.meta.copy()
156161

157-
after.replace_input_with(before, dq)
162+
# Always tag q as implicit
163+
tag_as_implicit_q_dq(q)
164+
165+
# Tag relevant q/ dq nodes
166+
# Ex: Original: G = conv -> q1 (Tag) -> dq1 (No Tag) -> output
167+
# Insert (copy q dq pattern), G = conv -> q1 -> dq1 -> (copy q2 dq2)-> output
168+
# dq1 (`before`) now feeds the inserted copy, so it is always tagged; dq2 is tagged only if dq1 was already tagged. With the two dequant labels swapped, this simulates
169+
# the pattern: G = conv -> q1 (Tag) -> (dq2 (Tag) copy q2 (Tag))-> dq1 (No Tag) -> output
170+
171+
if is_dequant(before) and is_tagged_as_implicit_q_dq(before):
172+
tag_as_implicit_q_dq(dq)
173+
if is_dequant(before):
174+
tag_as_implicit_q_dq(before)
175+
176+
before.replace_all_uses_with(dq)
177+
copy.replace_input_with(dq, before)
158178

159179
def insert_dq_copy_q(
160180
self,
@@ -170,15 +190,19 @@ def insert_dq_copy_q(
170190
target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
171191
args=(before,) + q_params,
172192
)
173-
dq.meta = before.meta
193+
dq.meta = before.meta.copy()
174194

175195
with graph_module.graph.inserting_after(copy):
176196
q = self.create_call_function_node(
177197
graph_module=graph_module,
178198
target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
179199
args=(copy,) + q_params,
180200
)
181-
q.meta = copy.meta
201+
q.meta = copy.meta.copy()
202+
203+
# Always tag q/dq as implicit
204+
tag_as_implicit_q_dq(dq)
205+
tag_as_implicit_q_dq(q)
182206

183207
copy.replace_input_with(before, dq)
184208
after.replace_input_with(before, q)

backends/xnnpack/_passes/conv1d_unsqueeze_pass.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88

99
import torch
1010
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
11-
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
11+
from executorch.backends.xnnpack.utils.quant_utils import (
12+
is_dequant,
13+
is_quant,
14+
tag_as_implicit_q_dq,
15+
)
1216
from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
1317
from executorch.exir.dialects._ops import ops as exir_ops
1418
from executorch.exir.pass_base import PassResult
@@ -51,15 +55,21 @@ def insert_q_dq_pair(
5155
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
5256
args=(), # We add the argument last
5357
)
54-
q.meta = anchor.meta
58+
q.meta = anchor.meta.copy()
59+
60+
# Tag q as implicit
61+
tag_as_implicit_q_dq(q)
5562

5663
with graph.inserting_after(q):
5764
dq = self.create_node(
5865
graph=graph,
5966
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
6067
args=(q,) + q_params,
6168
)
62-
dq.meta = q.meta
69+
dq.meta = q.meta.copy()
70+
71+
# Tag dq as implicit
72+
tag_as_implicit_q_dq(dq)
6373

6474
anchor.replace_all_uses_with(dq)
6575
# We add this last so the replace all uses above does not replace the quantized

backends/xnnpack/_passes/decompose_cat.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
import logging
88

99
import torch
10-
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
10+
from executorch.backends.xnnpack.utils.quant_utils import (
11+
is_dequant,
12+
is_quant,
13+
tag_as_implicit_q_dq,
14+
)
1115
from executorch.exir.dialects._ops import ops as exir_ops
1216

1317
from executorch.exir.pass_base import ExportPass, PassResult
@@ -79,13 +83,15 @@ def call(self, graph_module: torch.fx.GraphModule):
7983
args=(node,) + q_params,
8084
kwargs=q_kwargs,
8185
)
86+
tag_as_implicit_q_dq(q_node)
8287
with gm.graph.inserting_after(q_node):
8388
dq_node = gm.graph.create_node(
8489
"call_function",
8590
target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
8691
args=(q_node,) + q_params,
8792
kwargs=q_kwargs,
8893
)
94+
tag_as_implicit_q_dq(dq_node)
8995
remainder_concat_node.args = (
9096
[dq_node] + remainder_nodes_to_concat,
9197
) + node.args[1:]

backends/xnnpack/_passes/tag_implicit_q_dq_pass.py

Lines changed: 0 additions & 217 deletions
This file was deleted.

0 commit comments

Comments
 (0)