1 change: 0 additions & 1 deletion backends/xnnpack/_passes/TARGETS
@@ -9,7 +9,6 @@ python_library(
"//caffe2:torch",
"//executorch/backends/transforms:addmm_mm_to_linear",
"//executorch/backends/transforms:lib",
"//executorch/backends/xnnpack/partition:configs",
"//executorch/backends/xnnpack/partition:partitioner_graphs",
"//executorch/backends/xnnpack/serialization:xnnpack_schema",
"//executorch/backends/xnnpack/utils:xnnpack_utils",
4 changes: 0 additions & 4 deletions backends/xnnpack/_passes/__init__.py
@@ -25,9 +25,6 @@
FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
TagImplicitQDqPass,
)
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

from executorch.exir.pass_base import ExportPass
@@ -70,7 +67,6 @@ def __init__(
Conv1dUnsqueezePass,
PReLUReshapePass,
ChannelsLastTaggedReshapePass,
TagImplicitQDqPass,
]
else:
self.passes = passes
36 changes: 30 additions & 6 deletions backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
@@ -8,7 +8,12 @@

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.quant_utils import is_dynamic_qdq
from executorch.backends.xnnpack.utils.quant_utils import (
is_dequant,
is_dynamic_qdq,
is_tagged_as_implicit_q_dq,
tag_as_implicit_q_dq,
)
from executorch.backends.xnnpack.utils.utils import is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
@@ -144,17 +149,32 @@ def insert_copy_q_dq(
target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(copy,) + q_params,
)
q.meta = copy.meta
q.meta = copy.meta.copy()

with graph_module.graph.inserting_after(q):
dq = self.create_call_function_node(
graph_module=graph_module,
target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(q,) + q_params,
)
dq.meta = q.meta
dq.meta = q.meta.copy()

after.replace_input_with(before, dq)
# Always tag q as implicit
tag_as_implicit_q_dq(q)

# Tag relevant q/dq nodes
# Ex: Original: G = conv -> q1 (Tag) -> dq1 (No Tag) -> output
# Insert (copy q dq pattern): G = conv -> q1 -> dq1 -> (copy q2 dq2) -> output
# If dq1 is not tagged as implicit, then tag dq2 and swap dq1 and dq2 to simulate
# the pattern: G = conv -> q1 (Tag) -> (dq2 (Tag) copy q2 (Tag)) -> dq1 (No Tag) -> output

if is_dequant(before) and is_tagged_as_implicit_q_dq(before):
tag_as_implicit_q_dq(dq)
if is_dequant(before):
tag_as_implicit_q_dq(before)

before.replace_all_uses_with(dq)
copy.replace_input_with(dq, before)

def insert_dq_copy_q(
self,
@@ -170,15 +190,19 @@ def insert_dq_copy_q(
target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(before,) + q_params,
)
dq.meta = before.meta
dq.meta = before.meta.copy()

with graph_module.graph.inserting_after(copy):
q = self.create_call_function_node(
graph_module=graph_module,
target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(copy,) + q_params,
)
q.meta = copy.meta
q.meta = copy.meta.copy()

# Always tag q/dq as implicit
tag_as_implicit_q_dq(dq)

Reviewer comment (Contributor): I was talking with @mcr229 about this; implicit tagging of q/dq in more than one place can be brittle. Do we want to add a pass in the future to verify that the tagging is correct, just for validation? Feel free to create an issue; it doesn't have to be added right away.

tag_as_implicit_q_dq(q)

copy.replace_input_with(before, dq)
after.replace_input_with(before, q)
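
For context, the tagging helpers imported above come from executorch/backends/xnnpack/utils/quant_utils.py. A minimal sketch of what they presumably look like follows; the function names are taken from this diff's imports, but the node.meta key and the exact implementation are assumptions, not the library's actual code.

```python
import torch

# Assumed meta key; the real constant in quant_utils.py may be named differently.
_IMPLICIT_Q_DQ_TAG = "XNN_implicit_q_dq_tag"


def tag_as_implicit_q_dq(node: torch.fx.Node) -> None:
    # Mark a quantize/dequantize node as "implicit": it only carries the
    # quantization parameters of a neighboring delegated op and should not
    # be lowered as a standalone operator.
    node.meta[_IMPLICIT_Q_DQ_TAG] = True


def is_tagged_as_implicit_q_dq(node: torch.fx.Node) -> bool:
    return node.meta.get(_IMPLICIT_Q_DQ_TAG, False)
```

With tagging done at insertion time by each pass, the separate TagImplicitQDqPass (deleted at the bottom of this diff) no longer has to rediscover implicit q/dq nodes after the fact.
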
16 changes: 13 additions & 3 deletions backends/xnnpack/_passes/conv1d_unsqueeze_pass.py
@@ -8,7 +8,11 @@

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.backends.xnnpack.utils.quant_utils import (
is_dequant,
is_quant,
tag_as_implicit_q_dq,
)
from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
@@ -51,15 +55,21 @@ def insert_q_dq_pair(
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(), # We add the argument last
)
q.meta = anchor.meta
q.meta = anchor.meta.copy()

# Tag q as implicit
tag_as_implicit_q_dq(q)

with graph.inserting_after(q):
dq = self.create_node(
graph=graph,
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(q,) + q_params,
)
dq.meta = q.meta
dq.meta = q.meta.copy()

# Tag dq as implicit
tag_as_implicit_q_dq(dq)

anchor.replace_all_uses_with(dq)
# We add this last so the replace all uses above does not replace the quantized
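
The insert_q_dq_pair hunk above also shows why the quantize node's args are attached last: anchor.replace_all_uses_with(dq) must not rewire the freshly inserted q node onto dq. Below is a self-contained torch.fx sketch of that ordering, using torch.mul/torch.div as stand-ins for the edge quantize/dequantize ops and an assumed meta key in place of tag_as_implicit_q_dq.

```python
import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = fx.symbolic_trace(M())
graph = gm.graph
anchor = next(n for n in graph.nodes if n.target is torch.relu)

with graph.inserting_after(anchor):
    # Stand-in "quantize"; its args are attached last (see below).
    q = graph.call_function(torch.mul, args=())
q.meta["implicit_q_dq"] = True  # assumed tag key, not the real helper

with graph.inserting_after(q):
    # Stand-in "dequantize", reading from the stand-in quantize.
    dq = graph.call_function(torch.div, args=(q, 2.0))
dq.meta["implicit_q_dq"] = True

# Every existing user of `anchor` now reads from dq ...
anchor.replace_all_uses_with(dq)
# ... and only then is q connected to `anchor`, so q itself is not rewired.
q.args = (anchor, 2.0)

graph.lint()
gm.recompile()
print(gm.code)  # relu -> mul -> div -> add
```
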
8 changes: 7 additions & 1 deletion backends/xnnpack/_passes/decompose_cat.py
@@ -7,7 +7,11 @@
import logging

import torch
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.backends.xnnpack.utils.quant_utils import (
is_dequant,
is_quant,
tag_as_implicit_q_dq,
)
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
@@ -79,13 +83,15 @@ def call(self, graph_module: torch.fx.GraphModule):
args=(node,) + q_params,
kwargs=q_kwargs,
)
tag_as_implicit_q_dq(q_node)
with gm.graph.inserting_after(q_node):
dq_node = gm.graph.create_node(
"call_function",
target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(q_node,) + q_params,
kwargs=q_kwargs,
)
tag_as_implicit_q_dq(dq_node)
remainder_concat_node.args = (
[dq_node] + remainder_nodes_to_concat,
) + node.args[1:]
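
The tags inserted here (and in the passes above) matter downstream, when XNNPACK lowering decides whether a q/dq node needs its own conversion or is just metadata for a fused quantized op. That consumer-side check is not part of this diff; the sketch below is a guess at the kind of predicate involved, and the helper name needs_standalone_conversion is hypothetical.

```python
from torch.fx import Node

from executorch.backends.xnnpack.utils.quant_utils import (
    is_dequant,
    is_quant,
    is_tagged_as_implicit_q_dq,
)


def needs_standalone_conversion(node: Node) -> bool:  # hypothetical helper
    """Return True if a q/dq node should be lowered on its own rather than
    folded into a neighboring quantized op."""
    if not (is_quant(node) or is_dequant(node)):
        return False
    # Implicit q/dq nodes only carry quant params for an adjacent op;
    # everything else (e.g. q/dq at partition boundaries) stays explicit.
    return not is_tagged_as_implicit_q_dq(node)
```
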
217 changes: 0 additions & 217 deletions backends/xnnpack/_passes/tag_implicit_q_dq_pass.py

This file was deleted.

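
On the review question about verifying the tags: now that the dedicated TagImplicitQDqPass is gone, a lightweight report-only pass could sanity-check the inline tagging. A rough sketch is below, assuming an XNNPACKPass subclass with a call method is sufficient; the class name and the "report every untagged q/dq" criterion are assumptions, not part of this PR.

```python
import torch

from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.quant_utils import (
    is_dequant,
    is_quant,
    is_tagged_as_implicit_q_dq,
)
from executorch.exir.pass_base import PassResult


class VerifyImplicitQDqTagsPass(XNNPACKPass):  # hypothetical validation pass
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        untagged = [
            node
            for node in graph_module.graph.nodes
            if (is_quant(node) or is_dequant(node))
            and not is_tagged_as_implicit_q_dq(node)
        ]
        # Untagged q/dq nodes are not necessarily bugs (q/dq at partition
        # boundaries stays explicit), so report instead of failing.
        for node in untagged:
            print(f"q/dq node not tagged as implicit: {node.format_node()}")
        return PassResult(graph_module, False)
```
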