Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nncf/experimental/torch/fx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import torch
import torch.fx
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat
from torch.ao.quantization.pt2e.utils import _disallow_eval_train
Expand All @@ -27,6 +26,7 @@
from nncf.common.quantization.structs import QuantizationPreset
from nncf.data import Dataset
from nncf.experimental.torch.fx.quantization.backend_parameters import is_weight_compression_needed
from nncf.experimental.torch.fx.transformations import DuplicateDQPassNoAnnotations
from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
from nncf.experimental.torch.fx.transformations import fq_weights_transformation
Expand Down Expand Up @@ -102,7 +102,7 @@ def quantize_impl(
quantized_model = GraphModule(quantized_model, quantized_model.graph)

quantized_model = _fold_conv_bn_qat(quantized_model)
pm = PassManager([DuplicateDQPass()])
pm = PassManager([DuplicateDQPassNoAnnotations()])

quantized_model = pm(quantized_model).graph_module
pm = PassManager([PortNodeMetaForQDQ()])
Expand Down
4 changes: 2 additions & 2 deletions nncf/experimental/torch/fx/quantization/quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import torch
import torch.fx
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch.ao.quantization.pt2e.utils import _disallow_eval_train
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
Expand All @@ -33,6 +32,7 @@
from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
from nncf.experimental.torch.fx.quantization.quantizer.torch_ao_adapter import TorchAOQuantizerAdapter
from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
from nncf.experimental.torch.fx.transformations import DuplicateDQPassNoAnnotations
from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
Expand Down Expand Up @@ -132,7 +132,7 @@ def quantize_pt2e(
else:
constant_fold(quantized_model, _quant_node_constraint)

pm = PassManager([DuplicateDQPass()])
pm = PassManager([DuplicateDQPassNoAnnotations()])

quantized_model = pm(quantized_model).graph_module
pm = PassManager([PortNodeMetaForQDQ()])
Expand Down
85 changes: 85 additions & 0 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
from copy import copy
from typing import Any, Callable, Optional, Union

import torch
import torch.fx
from torch.ao.quantization.fx.utils import create_getattr_from_value
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
from torch.fx.node import map_arg
from torch.fx.passes.infra.pass_base import PassBase
from torch.fx.passes.infra.pass_base import PassResult
from torch.quantization.fake_quantize import FakeQuantize

import nncf
Expand Down Expand Up @@ -741,3 +745,84 @@ def constraint_fn(node: torch.fx.Node):
return node.op != "call_function" or node.target not in QUANTIZE_NODE_TARGETS + DEQUANTIZE_NODE_TARGETS

constant_fold(model, constraint_fn=constraint_fn)


def _duplicate_dq(gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node) -> None:
    """
    Gives the specified user node its own private copy of the given dequantizer node.

    A copy of `dq_node` is inserted into the graph right after the original, and every
    reference to `dq_node` found in the args and kwargs of `user` is redirected to that copy.
    Other users of `dq_node` are left untouched.

    :param gm: torch.fx.GraphModule instance that owns the nodes.
    :param dq_node: The original dequantizer node to be copied.
    :param user: The user node whose inputs are rewired to the new dequantizer copy.
    """
    graph = gm.graph
    with graph.inserting_after(dq_node):
        dq_copy = graph.node_copy(dq_node)

    def _swap(candidate: torch.fx.Node) -> torch.fx.Node:
        # Substitute only the dequantizer being duplicated; keep all other inputs as they are.
        return dq_copy if candidate == dq_node else candidate

    # Both mappings are computed before either assignment, as in the original flow.
    rewired_args = map_arg(user.args, _swap)
    rewired_kwargs = map_arg(user.kwargs, _swap)
    user.args = rewired_args
    user.kwargs = rewired_kwargs


def _is_sym_size_node(node: torch.fx.Node) -> bool:
    """
    Returns True if the given node is a sym size node instance, False otherwise.

    A node qualifies when it is a `call_function` node whose target is one of the aten
    `sym_size` / `sym_numel` operators (either the operator itself or its `default` overload).

    :param node: The given torch.fx.Node.
    :return: True if the given node is a sym size node instance, False otherwise.
    """
    # The `call_function` check must guard every target alternative. A bare
    # `A and B or C or D` chain binds as `(A and B) or C or D`, which would
    # leave all but the first target comparison unguarded.
    return node.op == "call_function" and node.target in (
        torch.ops.aten.sym_size.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_numel,
        torch.ops.aten.sym_size,
    )


class DuplicateDQPassNoAnnotations(PassBase):
    """
    Pass that duplicates Dequantizer (DQ) nodes so that each user node has a unique DQ instance.

    Users that are sym size nodes are not counted and do not receive a copy. DQ nodes that
    belong to the dynamic quantization pattern (choose_qparam - getitem - q - dq) are left shared.
    """

    @staticmethod
    def _is_dynamic_quantization_dq(dq_node: torch.fx.Node) -> bool:
        """
        Checks whether the given DQ node closes the pattern choose_qparam - getitem - q - dq.

        :param dq_node: The dequantizer node to inspect.
        :return: True if the pattern is matched, False otherwise.
        """
        q_node = dq_node.args[0]
        if q_node.op != "call_function" or q_node.target not in QUANTIZE_NODE_TARGETS:
            return False

        getitem_node = q_node.args[1]
        is_getitem = (
            isinstance(getitem_node, torch.fx.node.Node)
            and getitem_node.op == "call_function"
            and getitem_node.target == operator.getitem
        )
        if not is_getitem:
            return False

        producer = getitem_node.args[0]
        return (
            isinstance(producer, torch.fx.node.Node)
            and producer.op == "call_function"
            and producer.target == torch.ops.quantized_decomposed.choose_qparams.tensor
        )

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Invokes the pass.

        :param graph_module: The FX GraphModule to be transformed.
        :return: PassResult containing the modified graph module and a success flag.
        """
        for dq_node in graph_module.graph.nodes:
            if dq_node.op != "call_function" or dq_node.target not in DEQUANTIZE_NODE_TARGETS:
                continue

            # Sym size users are excluded from the count and keep the original DQ node.
            consumers = [consumer for consumer in dq_node.users if not _is_sym_size_node(consumer)]
            if len(consumers) <= 1:
                continue

            # Do not duplicate dq for dynamic quantization
            if self._is_dynamic_quantization_dq(dq_node):
                continue

            for consumer in consumers:
                _duplicate_dq(graph_module, dq_node, consumer)

        # Original DQ nodes left without users after rewiring are removed here.
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
Loading
Loading