Skip to content

Commit fc239a7

Browse files
DuplicateDQPass is replaced by DuplicateDQPassNoAnnotations
1 parent a0ced9a commit fc239a7

27 files changed

+13051
-12761
lines changed

nncf/experimental/torch/fx/quantization/quantize_pt2e.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import torch
1616
import torch.fx
17-
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
1817
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
1918
from torch.ao.quantization.pt2e.utils import _disallow_eval_train
2019
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
@@ -33,6 +32,7 @@
3332
from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
3433
from nncf.experimental.torch.fx.quantization.quantizer.torch_ao_adapter import TorchAOQuantizerAdapter
3534
from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
35+
from nncf.experimental.torch.fx.transformations import DuplicateDQPassNoAnnotations
3636
from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
3737
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
3838
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
@@ -132,7 +132,7 @@ def quantize_pt2e(
132132
else:
133133
constant_fold(quantized_model, _quant_node_constraint)
134134

135-
pm = PassManager([DuplicateDQPass()])
135+
pm = PassManager([DuplicateDQPassNoAnnotations()])
136136

137137
quantized_model = pm(quantized_model).graph_module
138138
pm = PassManager([PortNodeMetaForQDQ()])

nncf/experimental/torch/fx/transformations.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
import operator
1213
from copy import copy
1314
from typing import Any, Callable, Optional, Union
1415

1516
import torch
1617
import torch.fx
1718
from torch.ao.quantization.fx.utils import create_getattr_from_value
1819
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
20+
from torch.fx.node import map_arg
21+
from torch.fx.passes.infra.pass_base import PassBase
22+
from torch.fx.passes.infra.pass_base import PassResult
1923
from torch.quantization.fake_quantize import FakeQuantize
2024

2125
import nncf
@@ -741,3 +745,65 @@ def constraint_fn(node: torch.fx.Node):
741745
return node.op != "call_function" or node.target not in QUANTIZE_NODE_TARGETS + DEQUANTIZE_NODE_TARGETS
742746

743747
constant_fold(model, constraint_fn=constraint_fn)
748+
749+
750+
def _duplicate_dq(gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node):
    """
    Give `user` its own private copy of `dq_node`.

    A copy of `dq_node` is inserted into the graph right after the original node,
    and every occurrence of `dq_node` in the positional and keyword arguments of
    `user` is redirected to that copy. Other consumers of `dq_node` are not touched.

    :param gm: Graph module that owns both nodes.
    :param dq_node: Node to be copied.
    :param user: Consumer of `dq_node` that should read from the copy instead.
    """
    graph = gm.graph
    with graph.inserting_after(dq_node):
        dq_copy = graph.node_copy(dq_node)

    def _redirect(arg: torch.fx.Node) -> torch.fx.Node:
        # Only the references to the original node are swapped; anything else is kept as is.
        return dq_copy if arg == dq_node else arg

    # Both argument containers are remapped before either of them is assigned back to the user.
    user.args, user.kwargs = map_arg(user.args, _redirect), map_arg(user.kwargs, _redirect)
764+
765+
766+
def _is_sym_size_node(node: torch.fx.Node) -> bool:
    """
    Check whether the node is a call of one of the aten `sym_size` / `sym_numel` operations.

    :param node: Node to check.
    :return: True if the node is a `call_function` node whose target is
        `aten.sym_size`, `aten.sym_size.default`, `aten.sym_numel` or
        `aten.sym_numel.default`, False otherwise.
    """
    # The `call_function` check must guard every target comparison. Written as a flat
    # `a and b or c or d or e` chain, `and` binds tighter than `or`, so only the first
    # comparison would be guarded and the remaining ones would match nodes of any op kind.
    return node.op == "call_function" and node.target in (
        torch.ops.aten.sym_size.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_numel,
        torch.ops.aten.sym_size,
    )
774+
775+
776+
def _filter_sym_size_users(node: torch.fx.Node) -> list[torch.fx.Node]:
    """
    Collect the users of the node that are not symbolic size nodes.

    :param node: Node whose users are inspected.
    :return: Users of the node, in the iteration order of `node.users`, for which
        `_is_sym_size_node` does not hold.
    """
    return [consumer for consumer in node.users if not _is_sym_size_node(consumer)]
779+
780+
781+
class DuplicateDQPassNoAnnotations(PassBase):
    """
    Graph pass that gives every consumer of a shared dequantize node its own copy of that node.

    A dequantize node is processed only when it has more than one user after the
    symbolic size users are filtered out. Dequantize nodes that belong to the dynamic
    quantization pattern are left untouched.
    """

    @staticmethod
    def _closes_dynamic_quantization_pattern(dq_node: torch.fx.Node) -> bool:
        """
        Check whether the dequantize node ends the dynamic quantization pattern:
        choose_qparam - getitem - q - dq.

        :param dq_node: Dequantize node to inspect.
        :return: True if the pattern is matched, False otherwise.
        """
        quantize = dq_node.args[0]
        if not (quantize.op == "call_function" and quantize.target in QUANTIZE_NODE_TARGETS):
            return False

        qparam_item = quantize.args[1]
        if not isinstance(qparam_item, torch.fx.node.Node):
            return False
        if not (qparam_item.op == "call_function" and qparam_item.target == operator.getitem):
            return False

        qparams_source = qparam_item.args[0]
        if not isinstance(qparams_source, torch.fx.node.Node):
            return False
        return (
            qparams_source.op == "call_function"
            and qparams_source.target == torch.ops.quantized_decomposed.choose_qparams.tensor
        )

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the pass on the given graph module.

        :param graph_module: Graph module to transform in place.
        :return: Pass result that wraps the same graph module and is always marked as modified.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function" or node.target not in DEQUANTIZE_NODE_TARGETS:
                continue

            consumers = _filter_sym_size_users(node)
            if len(consumers) <= 1:
                continue

            # Do not duplicate dq for dynamic quantization
            if self._closes_dynamic_quantization_pattern(node):
                continue

            for consumer in consumers:
                _duplicate_dq(graph_module, node, consumer)

        # Every processed consumer now reads from its own copy, so nodes that are
        # left without users are dropped here before the module code is regenerated.
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

0 commit comments

Comments
 (0)