42 changes: 42 additions & 0 deletions src/nncf/experimental/torch/fx/quantization/qdq_parameters.py
@@ -0,0 +1,42 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

import torch


@dataclass
class TorchQDQParameters:
"""
Stores the quantization parameters required to create a PyTorch quantize-dequantize pair.
:param quant_min: Minimum quant value.
:type quant_min: int
:param quant_max: Maximum quant value.
:type quant_max: int
:param scale: Defines the scale factor used for quantization.
:type scale: torch.Tensor
:param zero_point: Specifies the quantized value to which 0 in floating point maps.
:type zero_point: torch.Tensor
:param is_per_channel: Whether quantization is applied per channel.
:type is_per_channel: bool
:param ch_axis: Channel axis used for per-channel quantization.
:type ch_axis: int
"""

Comment on lines +23 to +34 (Collaborator):
Suggested change (drop the :type lines):
:param quant_min: Minimum quant value.
:param quant_max: Maximum quant value.
:param scale: Defines the scale factor used for quantization.
:param zero_point: Specifies the quantized value to which 0 in floating point maps.
:param is_per_channel: Whether quantization is applied per channel.
:param ch_axis: Channel axis used for per-channel quantization.

Note: :type in docstrings is used only for API objects.

quant_min: int
quant_max: int
scale: torch.Tensor
zero_point: torch.Tensor
is_per_channel: bool
ch_axis: int
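For illustration, a minimal construction sketch (the values are invented for the example, not taken from this PR):

import torch

# Per-tensor signed int8 parameters; ch_axis is ignored when is_per_channel is False.
params = TorchQDQParameters(
    quant_min=-128,
    quant_max=127,
    scale=torch.tensor(0.02),
    zero_point=torch.tensor(0),
    is_per_channel=False,
    ch_axis=0,
)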
@@ -17,6 +17,7 @@
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat
from torch.ao.quantization.pt2e.utils import _disallow_eval_train
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_manager import PassManager

@@ -27,7 +28,6 @@
from nncf.data import Dataset
from nncf.experimental.torch.fx.quantization.backend_parameters import is_weight_compression_needed
from nncf.experimental.torch.fx.transformations import DuplicateDQPassNoAnnotations
from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
from nncf.experimental.torch.fx.transformations import fq_weights_transformation
from nncf.parameters import BackupMode
@@ -87,8 +87,9 @@ def quantize_impl(
advanced_parameters=advanced_parameters,
)

# To make it easier for bias correction algorithms.
apply_quantization_transformations(copied_model)
# Fuse batch norms into the convolutions,
# the same way it is done in torchao
_fuse_conv_bn_(copied_model)

nncf_graph = NNCFGraphFactory.create(copied_model)
quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)
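For context, _fuse_conv_bn_ folds each BatchNorm into the preceding convolution so that BatchNorm nodes disappear from the graph before quantization. A standalone sketch of the arithmetic it relies on (eval-mode equivalence; an illustration, not the torchao implementation):

import copy

import torch

def fold_bn_into_conv(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:
    # In eval mode, bn(conv(x)) == fused(x) with:
    #   W' = W * gamma / sqrt(var + eps)
    #   b' = (b - mean) * gamma / sqrt(var + eps) + beta
    fused = copy.deepcopy(conv)
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = (bn.weight / std).detach()
    fused.weight = torch.nn.Parameter(conv.weight.detach() * scale.reshape(-1, 1, 1, 1))
    bias = conv.bias.detach() if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias = torch.nn.Parameter((bias - bn.running_mean) * scale + bn.bias.detach())
    return fused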
15 changes: 11 additions & 4 deletions src/nncf/experimental/torch/fx/quantization/quantize_pt2e.py
@@ -14,10 +14,6 @@

import torch
import torch.fx
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch.ao.quantization.pt2e.utils import _disallow_eval_train
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_manager import PassManager

@@ -38,6 +34,17 @@
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.range_estimator import RangeEstimatorParameters

try:
from torchao.quantization.pt2e.quantizer import Quantizer
from torchao.quantization.pt2e.quantizer.port_metadata_pass import PortNodeMetaForQDQ
from torchao.quantization.pt2e.utils import _disallow_eval_train
from torchao.quantization.pt2e.utils import _fuse_conv_bn_
except ImportError:
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
from torch.ao.quantization.pt2e.utils import _disallow_eval_train
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
from torch.ao.quantization.quantizer import Quantizer


@api(canonical_alias="nncf.experimental.torch.fx.quantize_pt2e")
def quantize_pt2e(
@@ -13,14 +13,6 @@
from typing import Optional, Union

import torch.fx
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import EdgeOrNode
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation as TorchAOQuantizationAnnotation
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as TorchAOQuantizationSpec
from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase as TorchAOQuantizationSpecBase
from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec

import nncf
from nncf import IgnoredScope
@@ -43,6 +35,25 @@
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids

try:
from torchao.quantization.pt2e.observer import HistogramObserver
from torchao.quantization.pt2e.observer import PerChannelMinMaxObserver
from torchao.quantization.pt2e.quantizer.quantizer import EdgeOrNode
from torchao.quantization.pt2e.quantizer.quantizer import QuantizationAnnotation as TorchAOQuantizationAnnotation
from torchao.quantization.pt2e.quantizer.quantizer import QuantizationSpec as TorchAOQuantizationSpec
from torchao.quantization.pt2e.quantizer.quantizer import QuantizationSpecBase as TorchAOQuantizationSpecBase
from torchao.quantization.pt2e.quantizer.quantizer import Quantizer as TorchAOQuantizer
from torchao.quantization.pt2e.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec
except ImportError:
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import EdgeOrNode
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation as TorchAOQuantizationAnnotation
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as TorchAOQuantizationSpec
from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase as TorchAOQuantizationSpecBase
from torch.ao.quantization.quantizer.quantizer import Quantizer as TorchAOQuantizer
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec as TorchAOSharedQuantizationSpec

QUANT_ANNOTATION_KEY = "quantization_annotation"


@@ -15,11 +15,6 @@

import torch
import torch.fx
from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
from torch.ao.quantization.quantizer import Quantizer as TorchAOQuantizer
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec

import nncf
from nncf.common.graph.graph import NNCFGraph
@@ -34,12 +29,26 @@
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.tensor.definitions import TensorDataType

try:
from torchao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
from torchao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
from torchao.quantization.pt2e.quantizer import Quantizer as TorchAOQuantizer
from torchao.quantization.pt2e.quantizer.quantizer import QuantizationSpec
from torchao.quantization.pt2e.quantizer.quantizer import SharedQuantizationSpec
except ImportError:
from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
from torch.ao.quantization.quantizer import Quantizer as TorchAOQuantizer
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec


EdgeOrNode = Union[tuple[torch.fx.Node, torch.fx.Node]]


class TorchAOQuantizerAdapter(Quantizer):
"""
Implementation of the NNCF Quantizer interface for any given torch.ao quantizer.
Implementation of the NNCF Quantizer interface for any given torchao quantizer.
"""

def __init__(self, quantizer: TorchAOQuantizer):
@@ -120,7 +129,7 @@ def _get_node_args(node: torch.fx.Node) -> tuple[Any, ...]:
def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
"""
Processes a torch.fx.GraphModule annotated with quantization specifications
(e.g., via torch.ao observers) and generates a corresponding NNCF quantization setup object,
(e.g., via torchao observers) and generates a corresponding NNCF quantization setup object,
which maps quantization configurations to graph edges.

:param annotated: A torch.fx.GraphModule that has been annotated with Torch quantization observers.
@@ -149,7 +158,7 @@ def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -
if qspec is None:
continue
if not isinstance(qspec, QuantizationSpec):
msg = f"Unknown torch.ao quantization spec: {qspec}"
msg = f"Unknown torchao quantization spec: {qspec}"
raise nncf.InternalError(msg)

if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
108 changes: 71 additions & 37 deletions src/nncf/experimental/torch/fx/transformations.py
@@ -15,12 +15,9 @@

import torch
import torch.fx
from torch.ao.quantization.fx.utils import create_getattr_from_value
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
from torch.fx.node import map_arg
from torch.fx.passes.infra.pass_base import PassBase
from torch.fx.passes.infra.pass_base import PassResult
from torch.quantization.fake_quantize import FakeQuantize

import nncf
import nncf.torch
Expand All @@ -29,6 +26,7 @@
from nncf.experimental.torch.fx.constant_folding import constant_fold
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
from nncf.experimental.torch.fx.quantization.qdq_parameters import TorchQDQParameters
from nncf.torch.graph.transformations.commands import PTTargetPoint

TransformationFNType = Callable[[torch.fx.GraphModule], None]
@@ -223,16 +221,16 @@ def constant_update_fn(


def qdq_insertion_transformation_builder(
quantizer: FakeQuantize, target_points: list[PTTargetPoint]
parameters: TorchQDQParameters, target_points: list[PTTargetPoint]
) -> TransformationFNType:
"""
Returns transformation which inserts quantize-dequantize operations with parameters
inherited from the given quantizer to each given target point.
Returns transformation which inserts quantize-dequantize operations with
the given parameters to each given target point.

:param quantizer: Quantizer module to inherit quantization parameters from.
:param parameters: Quantization parameters.
:param target_points: List of target points used to insert quantize-dequantize pairs.
:return: Transformation which inserts quantize-dequantize operations with parameters
inherited from the given quantizer to each given target point.
:return: Transformation which inserts quantize-dequantize operations with
the given parameters to each given target point.
"""

def qdq_insertion_transformation(model: torch.fx.GraphModule):
Expand All @@ -243,7 +241,7 @@ def qdq_insertion_transformation(model: torch.fx.GraphModule):
)
raise nncf.InternalError(msg)
for target_point in target_points:
insert_one_qdq(model, target_point, quantizer)
insert_one_qdq(model, target_point, parameters)

return qdq_insertion_transformation

@@ -311,38 +309,38 @@ def output_insertion_transformation(model: torch.fx.GraphModule):
return output_insertion_transformation


def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, quantizer: FakeQuantize):
def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, parameters: TorchQDQParameters):
"""
Inserts a quantize-dequantize pair after the target node in the target model.

:param model: Target model.
:param target_point: Target point; the quantize-dequantize pair is inserted just after
its target node.
:param quantizer: Quantizer module to inherit quantization parameters from.
:param parameters: Quantization parameters.
"""
# Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e
# Copied from torchao.quantization.quantize_pt2e.convert_pt2e
# 1. extract information for inserting q/dq node from activation_post_process
node_type = "call_function"
quantize_op: Optional[Callable] = None

dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8
if quantizer.is_per_channel:
dtype = torch.int8 if parameters.quant_min < 0 else torch.uint8
if parameters.is_per_channel:
qparams = {
"_scale_": quantizer.scale,
"_zero_point_": quantizer.zero_point,
"_axis_": quantizer.ch_axis,
"_quant_min_": quantizer.quant_min,
"_quant_max_": quantizer.quant_max,
"_scale_": parameters.scale,
"_zero_point_": parameters.zero_point,
"_axis_": parameters.ch_axis,
"_quant_min_": parameters.quant_min,
"_quant_max_": parameters.quant_max,
"_dtype_": dtype,
}
quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
else:
qparams = {
"_scale_": float(quantizer.scale),
"_zero_point_": int(quantizer.zero_point),
"_quant_min_": quantizer.quant_min,
"_quant_max_": quantizer.quant_max,
"_scale_": float(parameters.scale),
"_zero_point_": int(parameters.zero_point),
"_quant_min_": parameters.quant_min,
"_quant_max_": parameters.quant_max,
"_dtype_": dtype,
}
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
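For intuition, a reference sketch of what the inserted per-tensor pair computes end to end (mirroring the quantized_decomposed quantize/dequantize ops; illustration only):

import torch

def qdq_reference(x: torch.Tensor, scale: float, zero_point: int, quant_min: int, quant_max: int) -> torch.Tensor:
    # Quantize: project onto the integer grid and clamp to [quant_min, quant_max].
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    # Dequantize: map back to float; the result carries the quantization error.
    return (q - zero_point) * scale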
@@ -721,19 +719,6 @@ def match_filters(match, original_graph, graph):
_set_meta_for_matches(model, matches)


def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
"""
Applies quantization transformations to the model.

:param model: Model to apply transformations to.
"""
# BatchNorm operations have 3 output ports,
# to make it easier for algorithms to work
# with the target graph BatchNorm operations
# are being fused
_fuse_conv_bn_(model)


def fold_constant_except_qdq(model: torch.fx.GraphModule):
"""
Performs constant folding avoiding quantize-dequantize pattern.
@@ -826,3 +811,52 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)


def get_device(module: torch.nn.Module) -> torch.device:
Comment (Collaborator): Please reuse the existing helper:
def get_model_device(model: torch.nn.Module) -> torch.device:

"""
Retrieves the device of the first parameter of the given module.
If the module has no parameters, returns the CPU device.

:param module: A torch.nn.Module instance.
:return: The device of the first parameter of the given module, or the CPU
device if the module has no parameters.
"""
try:
named_param = next(module.parameters())
except StopIteration:
named_param = None
if named_param is None:
return torch.device("cpu")
return named_param.device


def create_getattr_from_value(module: torch.nn.Module, graph: torch.fx.Graph, prefix: str, value: Any) -> torch.fx.Node:
Comment (Collaborator): I did not find a case where value is not a torch.Tensor; is Any really needed here?

"""
Given a value of any type, creates a getattr node corresponding to the value and
registers the value as a buffer to the module.

:param module: A torch.nn.Module instance.
:param graph: A torch.fx.Graph instance.
:param prefix: A string to use as a name prefix for the new getattr node.
:param value: A value to register as a buffer (non-tensors are converted to torch.Tensor).
:return: A getattr node corresponding to the given value.
"""

def get_new_attr_name(module: torch.nn.Module, prefix: str):
def get_attr_name(i: int):
return prefix + str(i)

i = 0
attr_name = get_attr_name(i)
while hasattr(module, attr_name):
i += 1
attr_name = get_attr_name(i)
return attr_name

attr_name = get_new_attr_name(module, prefix.replace(".", "_"))
device = get_device(module)
new_value = value.detach().clone() if isinstance(value, torch.Tensor) else torch.tensor(value, device=device)
module.register_buffer(attr_name, new_value)
attr_node = graph.create_node("get_attr", attr_name)
return attr_node
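A small usage sketch (the module and value are invented for the example): register a constant on a traced GraphModule and reference it from its graph.

import torch
import torch.fx

class AddConst(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

gm = torch.fx.symbolic_trace(AddConst())
output_node = next(n for n in gm.graph.nodes if n.op == "output")
with gm.graph.inserting_before(output_node):
    # Registers a buffer named "my_scale<i>" on gm and returns a get_attr node for it.
    const_node = create_getattr_from_value(gm, gm.graph, "my_scale", torch.tensor(0.5))
gm.recompile()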