13 changes: 12 additions & 1 deletion src/nncf/common/quantization/quantizer_setup.py
@@ -21,6 +21,7 @@
from nncf.common.quantization.structs import NonWeightQuantizerId
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.quantization.structs import TypedQuantizerConfig
from nncf.common.quantization.structs import UnifiedScaleType
from nncf.common.quantization.structs import WeightQuantizerId
from nncf.common.stateful_classes_registry import CommonStatefulClassesRegistry
@@ -193,9 +194,19 @@ def from_state(cls, state: dict[str, Any]) -> "SingleConfigQuantizationPoint":
        insertion_point_cls_name = state[cls._state_names.INSERTION_POINT_CLASS_NAME]
        insertion_point_cls = CommonStatefulClassesRegistry.get_registered_class(insertion_point_cls_name)
        insertion_point = insertion_point_cls.from_state(state[cls._state_names.INSERTION_POINT])  # type: ignore
        qconfig_state = state[cls._state_names.QCONFIG]
        # Need to instantiate TypedQuantizerConfig
        # to support additional fields used by ExecuTorch-specific quantizer configs.
        # TODO (dlyakhov): Refactor and generalize quantizer config deserialization to cleanly handle both
        # standard and extended formats without relying on manual key comparison (ticket 170078).
        if QuantizerConfig().__dict__.keys() == qconfig_state.keys():
            qconfig = QuantizerConfig.from_state(qconfig_state)
        else:
            qconfig = TypedQuantizerConfig.from_state(qconfig_state)

        kwargs = {
            cls._state_names.INSERTION_POINT: insertion_point,
            cls._state_names.QCONFIG: QuantizerConfig.from_state(state[cls._state_names.QCONFIG]),
            cls._state_names.QCONFIG: qconfig,
            cls._state_names.NAMES_OF_QUANTIZED_OPS: state[cls._state_names.NAMES_OF_QUANTIZED_OPS],
        }
        return cls(**kwargs)
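For context, here is a minimal, self-contained sketch of the key-set dispatch that `from_state` now performs. The dataclasses below are hypothetical stand-ins for `QuantizerConfig`/`TypedQuantizerConfig`, not the NNCF classes:

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class BaseCfg:  # stand-in for QuantizerConfig
    num_bits: int = 8
    mode: str = "symmetric"
    signedness_to_force: Optional[bool] = None
    per_channel: bool = False
    narrow_range: bool = False


@dataclass
class TypedCfg(BaseCfg):  # stand-in for TypedQuantizerConfig
    dest_dtype: str = "int8"


def cfg_from_state(state: dict) -> BaseCfg:
    # Same dispatch as from_state above: the attribute keys of a default
    # instance identify the plain format; any extra keys select the typed one.
    if state.keys() == asdict(BaseCfg()).keys():
        return BaseCfg(**state)
    return TypedCfg(**state)


base = {"num_bits": 8, "mode": "symmetric", "signedness_to_force": None,
        "per_channel": False, "narrow_range": False}
print(type(cfg_from_state(base)).__name__)                             # BaseCfg
print(type(cfg_from_state({**base, "dest_dtype": "int8"})).__name__)   # TypedCfg
```

Comparing key sets avoids adding a version field to the serialized state, but, as the TODO notes, it becomes brittle as soon as more config variants appear.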
43 changes: 42 additions & 1 deletion src/nncf/common/quantization/structs.py
@@ -11,7 +11,7 @@

from copy import deepcopy
from enum import Enum
from typing import Any, Optional
from typing import Any, Literal, Optional

import nncf
from nncf.common.graph import NNCFNode
@@ -22,6 +22,9 @@
from nncf.config.schemata.defaults import QUANTIZATION_PER_CHANNEL
from nncf.parameters import StrEnum
from nncf.parameters import TargetDevice
from nncf.tensor.definitions import TensorDataType

IntDtype = Literal[TensorDataType.int8, TensorDataType.uint8]


@api()
@@ -421,3 +424,41 @@ def get_params_configured_by_preset(self, quant_group: QuantizerGroup) -> dict[s
        if quant_group == QuantizerGroup.ACTIVATIONS and self == QuantizationPreset.MIXED:
            return {"mode": QuantizationScheme.ASYMMETRIC}
        return {"mode": QuantizationScheme.SYMMETRIC}


class TypedQuantizerConfig(QuantizerConfig):
    """
    Extended configuration class for quantizers, including the destination integer dtype.
    """

    def __init__(
        self,
        num_bits: int = QUANTIZATION_BITS,
        mode: QuantizationScheme = QuantizationScheme.SYMMETRIC,
        signedness_to_force: Optional[bool] = None,
        per_channel: bool = QUANTIZATION_PER_CHANNEL,
        narrow_range: bool = QUANTIZATION_NARROW_RANGE,
        dest_dtype: IntDtype = TensorDataType.int8,
    ):
        """
        :param num_bits: Bitwidth of the quantization.
        :param mode: The mode of quantization (symmetric or asymmetric).
        :param signedness_to_force: True if the quantizer *must* be signed, False if it *must* be unsigned,
            None if the signed/unsigned attribute should be determined based on the incoming activation
            statistics during range initialization.
        :param per_channel: True for per-channel quantization, False for per-tensor.
        :param narrow_range: True if the range of quantized values should be narrowed as compared to the
            naive case, False if all 2^`num_bits` quantization levels should be used.
        :param dest_dtype: Target integer data type for quantized values.
        """
        super().__init__(num_bits, mode, signedness_to_force, per_channel, narrow_range)
        self.dest_dtype = dest_dtype

    def __str__(self) -> str:
        return super().__str__() + f" DestDtype: {self.dest_dtype}"

    def get_state(self) -> dict[str, Any]:
        state = super().get_state()
        state["dest_dtype"] = self.dest_dtype
        return state
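A hedged usage sketch of the new class: `from_state` is inherited from `QuantizerConfig` and is assumed here to reconstruct the config from its state dict.

```python
from nncf.common.quantization.structs import TypedQuantizerConfig
from nncf.tensor.definitions import TensorDataType

cfg = TypedQuantizerConfig(num_bits=8, dest_dtype=TensorDataType.uint8)
state = cfg.get_state()  # base QuantizerConfig state plus "dest_dtype"
assert "dest_dtype" in state
restored = TypedQuantizerConfig.from_state(state)
assert restored.dest_dtype == cfg.dest_dtype
```

The following hunks, from the TorchFX adapter that translates torch.ao quantizer annotations into NNCF quantization points, construct this typed config directly.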
@@ -29,9 +29,10 @@
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.quantization.structs import TypedQuantizerConfig
from nncf.experimental.quantization.quantizer import Quantizer
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.tensor.definitions import TensorDataType

EdgeOrNode = Union[tuple[torch.fx.Node, torch.fx.Node]]

@@ -71,15 +72,15 @@ def _get_quantization_points(
from_node: torch.fx.Node,
to_nodes: list[torch.fx.Node],
annotated_model: torch.fx.GraphModule,
qconfig: QuantizerConfig,
qconfig: TypedQuantizerConfig,
) -> list[QuantizationPointBase]:
"""
Creates quantization points based on the nodes and edges.

:param from_node: The originating node in the computation graph.
:param to_nodes: The list of destination nodes of the from_node.
:param annotated_model: The torch.fx.GraphModule instance.
:param qconfig: The torch.ao quantization configuration.
:param qconfig: The TorchFX quantization configuration.
:return: A list of NNCF quantization points.
"""
to_n = to_nodes[0]
@@ -159,15 +160,19 @@ def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -
msg = f"Unknown qscheme: {qspec.qscheme}"
raise nncf.InternalError(msg)

signed = qspec.dtype is torch.int8
dtype = TensorDataType.int8 if qspec.dtype is torch.int8 else TensorDataType.uint8
mode = (
QuantizationMode.SYMMETRIC
if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
else QuantizationMode.ASYMMETRIC
)
narrow_range = qspec.quant_min % 2 != 0
qconfig = QuantizerConfig(
mode=mode, signedness_to_force=signed, per_channel=per_channel, narrow_range=narrow_range
narrow_range = qspec.quant_max - qspec.quant_min == 254
qconfig = TypedQuantizerConfig(
mode=mode,
signedness_to_force=False,
per_channel=per_channel,
narrow_range=narrow_range,
dest_dtype=dtype,
)

joined_edges = defaultdict(list)
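The new narrow-range test can be checked by hand against the four standard 8-bit `(quant_min, quant_max)` pairs; this is plain arithmetic, not NNCF code:

```python
ranges = {
    "int8 full":    (-128, 127),  # 256 levels, span 255
    "int8 narrow":  (-127, 127),  # 255 levels, span 254
    "uint8 full":   (0, 255),     # span 255
    "uint8 narrow": (0, 254),     # span 254
}
for name, (qmin, qmax) in ranges.items():
    old = qmin % 2 != 0       # previous heuristic: odd quant_min
    new = qmax - qmin == 254  # new heuristic: span of 254 levels
    print(f"{name:13s} old={old!s:5} new={new}")
# Only the span test flags "uint8 narrow" ([0, 254]); the odd-quant_min
# heuristic treated it as a full range.
```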
5 changes: 5 additions & 0 deletions src/nncf/parameters.py
@@ -10,6 +10,7 @@
# limitations under the License.

from enum import Enum
from typing import Any

from nncf.common.utils.api_marker import api

@@ -18,6 +19,10 @@ class StrEnum(str, Enum):
    def __str__(self) -> str:
        return str(self.value)

    @staticmethod
    def _generate_next_value_(name: str, start: int, count: int, last_values: list[Any]) -> Any:
        return name.lower()


@api(canonical_alias="nncf.TargetDevice")
class TargetDevice(StrEnum):
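A minimal sketch of what the new `_generate_next_value_` hook changes: with `auto()`, each member's value becomes its lowercased name, so members compare and serialize as plain strings. The `Color` enum is illustrative only, not part of NNCF:

```python
from enum import Enum, auto
from typing import Any


class StrEnum(str, Enum):  # mirrors the class in the diff above
    def __str__(self) -> str:
        return str(self.value)

    @staticmethod
    def _generate_next_value_(name: str, start: int, count: int, last_values: list[Any]) -> Any:
        return name.lower()


class Color(StrEnum):  # hypothetical enum, for illustration only
    RED = auto()
    DARK_BLUE = auto()


assert str(Color.RED) == "red"
assert Color.DARK_BLUE == "dark_blue"  # str subclass: equals its value
```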
136 changes: 73 additions & 63 deletions src/nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -23,7 +23,9 @@
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.hardware.config import HWConfig
from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
from nncf.common.quantization.structs import QuantizationScheme
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.quantization.structs import TypedQuantizerConfig
from nncf.experimental.common.tensor_statistics.collectors import REDUCERS_MAP
from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
@@ -35,6 +37,7 @@
from nncf.quantization.fake_quantize import FakeConvertParameters
from nncf.quantization.fake_quantize import FakeQuantizeParameters
from nncf.quantization.range_estimator import StatisticsType
from nncf.tensor.definitions import TensorDataType
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph import PTTargetPoint
from nncf.torch.graph.operator_metatypes import ELEMENTWISE_OPERATIONS
@@ -46,12 +49,7 @@
from nncf.torch.model_graph_manager import is_matmul_with_constant
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.quantization.layers import QUANTIZATION_MODULES
from nncf.torch.quantization.layers import AsymmetricQuantizer
from nncf.torch.quantization.layers import BaseQuantizer
from nncf.torch.quantization.layers import PTQuantizerSpec
from nncf.torch.quantization.layers import get_scale_shape
from nncf.torch.quantization.strip import convert_to_torch_fakequantizer
from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high


class FXMinMaxAlgoBackend(MinMaxAlgoBackend):
@@ -179,63 +177,83 @@ def get_weight_config(config: QuantizerConfig, model: NNCFNetwork) -> QuantizerC
return config

@staticmethod
def _get_input_scale_shape(
nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool
) -> tuple[tuple[int, ...], tuple[int, ...], int]:
is_weights = target_point.is_weight_target_point()
if is_weights:
def _get_channel_axis(is_weight_quantizer: bool) -> int:
if is_weight_quantizer:
# TODO(dlyakhov): support transpose conv/ make channel_idx common
channel_idx = 0
else:
channel_idx = 1 # channel dim for activations

input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
scale_shape = tuple(
get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx)
)

return input_shape, scale_shape, channel_idx
return 0
return 1

@staticmethod
def _create_quantizer(
quantizer_config: QuantizerConfig,
scale_shape: tuple,
parameters: FakeQuantizeParameters,
target_type: TargetType,
is_weight_quantizer: bool,
) -> FakeQuantize:
mode = quantizer_config.mode
quantizer_cls = QUANTIZATION_MODULES.get(mode)
quantizer_spec = PTQuantizerSpec.from_config(
quantizer_config,
narrow_range=quantizer_config.narrow_range,
scale_shape=scale_shape,
half_range=False,
logarithm_scale=False,
is_quantized_on_export=False,
compression_lr_multiplier=None,
)
quantizer = quantizer_cls(quantizer_spec)
per_channel = quantizer_config.per_channel
dtype = None
if isinstance(quantizer_config, TypedQuantizerConfig):
dtype = quantizer_config.dest_dtype

# Fill it with minmax
# TODO(dlyakhov) Prevent creation of intermediate objects like nncf quantizer.
FXMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape)
# Convert to the torch fake quantizer
torch_fq = convert_to_torch_fakequantizer(quantizer)
return torch_fq
if dtype not in [TensorDataType.int8, TensorDataType.uint8]:
msg = f"Quantization configurations with dest_dtype=={dtype} are not supported."
raise nncf.ParameterNotSupportedError(msg)

@staticmethod
def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None:
if isinstance(quantizer, AsymmetricQuantizer):
quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape))
input_range = parameters.input_high - parameters.input_low
# Subtract eps from the input_range to make quantizer parameters equal to
# original parameters on the forward call.
quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape))
elif quantizer_config.mode != QuantizationScheme.SYMMETRIC:
dtype = TensorDataType.uint8
else:
dtype = (
TensorDataType.int8
if quantizer_config.signedness_to_force or torch.any(parameters.input_low.data < 0.0)
else TensorDataType.uint8
)

if per_channel:
observer = torch.ao.quantization.observer.PerChannelMinMaxObserver
else:
observer = torch.ao.quantization.observer.MinMaxObserver

if dtype is TensorDataType.int8:
level_high = 127
level_low = -128
else:
level_high = 255
level_low = 0

if quantizer_config.narrow_range:
if level_low < 0:
level_low += 1
else:
level_high -= 1

if quantizer_config.mode == QuantizationScheme.SYMMETRIC:
qscheme = torch.per_channel_symmetric if per_channel else torch.per_tensor_symmetric
else:
quantizer.signed = bool(torch.any(parameters.input_low.data < 0))
# Subtract eps from the scale to make quantizer parameters equal to
# original parameters on the forward call.
quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape))
qscheme = torch.per_channel_affine if per_channel else torch.per_tensor_affine

scale, zero_point = get_scale_zp_from_input_low_input_high(
level_low, level_high, parameters.input_low.data, parameters.input_high.data
)

scale = scale.view(-1)
zero_point = zero_point.view(-1)

fakequantizer = FakeQuantize(
observer=observer,
quant_max=level_high,
quant_min=level_low,
dtype=torch.qint8 if dtype is TensorDataType.int8 else torch.quint8,
qscheme=qscheme,
eps=1e-16,
)

fakequantizer.scale = scale
fakequantizer.zero_point = zero_point
if per_channel:
fakequantizer.ch_axis = FXMinMaxAlgoBackend._get_channel_axis(is_weight_quantizer)

# Disable observer to save parameters
fakequantizer.disable_observer()
return fakequantizer

@staticmethod
def create_quantizer_insertion_command(
@@ -244,12 +262,8 @@ def create_quantizer_insertion_command(
quantizer_config: QuantizerConfig,
parameters: FakeQuantizeParameters,
) -> FXApplyTransformationCommand:
_, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape(
nncf_graph, target_point, quantizer_config.per_channel
)

quantizer = FXMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_point.target_type
quantizer_config, parameters, target_point.is_weight_target_point()
)
transformation = qdq_insertion_transformation_builder(quantizer, [target_point])
return FXApplyTransformationCommand(transformation)
@@ -261,12 +275,8 @@ def create_unified_scales_quantizers_insertion_commands(
quantizer_config: QuantizerConfig,
parameters: FakeQuantizeParameters,
) -> list[PTSharedFnInsertionCommand]:
_, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape(
nncf_graph, target_points[0], quantizer_config.per_channel
)

quantizer = FXMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_points[0].target_type
quantizer_config, parameters, target_points[0].is_weight_target_point()
)

transformations = []
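The level arithmetic in `_create_quantizer` can be traced without torch. The `levels` helper below mirrors the backend's branches, and `scale_zp` is the standard affine derivation, assumed to approximate what `get_scale_zp_from_input_low_input_high` computes (rounding and clamping details in the NNCF helper may differ):

```python
def levels(signed: bool, narrow_range: bool) -> tuple[int, int]:
    # Pick the 8-bit range for the dest dtype, then drop one level from the
    # wider side for narrow range — exactly the branch structure above.
    level_low, level_high = (-128, 127) if signed else (0, 255)
    if narrow_range:
        if level_low < 0:
            level_low += 1
        else:
            level_high -= 1
    return level_low, level_high


assert levels(signed=True, narrow_range=False) == (-128, 127)
assert levels(signed=True, narrow_range=True) == (-127, 127)
assert levels(signed=False, narrow_range=False) == (0, 255)
assert levels(signed=False, narrow_range=True) == (0, 254)


def scale_zp(input_low: float, input_high: float,
             level_low: int, level_high: int) -> tuple[float, int]:
    # Standard affine quantization parameters for a float range mapped onto
    # the integer level range.
    scale = (input_high - input_low) / (level_high - level_low)
    zero_point = round(level_low - input_low / scale)
    return scale, zero_point


print(scale_zp(-1.0, 1.0, *levels(signed=True, narrow_range=False)))  # ~(0.0078, 0)
```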
4 changes: 3 additions & 1 deletion src/nncf/tensor/definitions.py
@@ -14,6 +14,8 @@
from enum import auto
from typing import Optional, Union

from nncf.parameters import StrEnum

T_SHAPE_ARRAY = tuple[int, ...]
T_SHAPE = Union[int, T_SHAPE_ARRAY]
T_AXIS = Optional[T_SHAPE]
@@ -31,7 +33,7 @@ class TensorBackend(Enum):
    ov = auto()


class TensorDataType(Enum):
class TensorDataType(StrEnum):
"""
Enum representing the different tensor data types.
"""
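Why the base-class swap matters: as a `StrEnum`, `TensorDataType` members are real strings, which keeps `get_state()` output JSON-friendly with no custom encoder. The check below assumes the members are declared with `auto()`, so their values are the lowercased member names:

```python
import json

from nncf.tensor.definitions import TensorDataType

# str-subclass enums serialize as their string value.
print(json.dumps({"dest_dtype": TensorDataType.int8}))  # {"dest_dtype": "int8"}
assert TensorDataType.int8 == "int8"
```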