Skip to content

Commit 45c43a7

Browse files
experimental/structs.py -> quantization/structs.py
1 parent 714d77d commit 45c43a7

File tree

7 files changed

+57
-75
lines changed

7 files changed

+57
-75
lines changed

src/nncf/common/quantization/quantizer_setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
from nncf.common.quantization.structs import NonWeightQuantizerId
2222
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
2323
from nncf.common.quantization.structs import QuantizerConfig
24+
from nncf.common.quantization.structs import TypedQuantizerConfig
2425
from nncf.common.quantization.structs import UnifiedScaleType
2526
from nncf.common.quantization.structs import WeightQuantizerId
2627
from nncf.common.stateful_classes_registry import CommonStatefulClassesRegistry
2728
from nncf.config.schemata.defaults import QUANTIZATION_NARROW_RANGE
28-
from nncf.experimental.quantization.structs import ExtendedQuantizerConfig
2929

3030
QuantizationPointId = int
3131

@@ -195,14 +195,14 @@ def from_state(cls, state: dict[str, Any]) -> "SingleConfigQuantizationPoint":
195195
insertion_point_cls = CommonStatefulClassesRegistry.get_registered_class(insertion_point_cls_name)
196196
insertion_point = insertion_point_cls.from_state(state[cls._state_names.INSERTION_POINT]) # type: ignore
197197
qconfig_state = state[cls._state_names.QCONFIG]
198-
# Need to instantiate ExtendedQuantizerConfig
198+
# Need to instantiate TypedQuantizerConfig
199199
# to support additional fields used by ExecuTorch-specific quantizer configs.
200200
# TODO (dlyakhov): Refactor and generalize quantizer config deserialization to cleanly handle both
201201
# standard and extended formats without relying on manual key comparison (ticket 170078).
202202
if QuantizerConfig().__dict__.keys() == qconfig_state.keys():
203203
qconfig = QuantizerConfig.from_state(qconfig_state)
204204
else:
205-
qconfig = ExtendedQuantizerConfig.from_state(qconfig_state)
205+
qconfig = TypedQuantizerConfig.from_state(qconfig_state)
206206

207207
kwargs = {
208208
cls._state_names.INSERTION_POINT: insertion_point,

src/nncf/common/quantization/structs.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from copy import deepcopy
1313
from enum import Enum
14-
from typing import Any, Optional
14+
from typing import Any, Literal, Optional
1515

1616
import nncf
1717
from nncf.common.graph import NNCFNode
@@ -22,6 +22,9 @@
2222
from nncf.config.schemata.defaults import QUANTIZATION_PER_CHANNEL
2323
from nncf.parameters import StrEnum
2424
from nncf.parameters import TargetDevice
25+
from nncf.tensor.definitions import TensorDataType
26+
27+
IntDtype = Literal[TensorDataType.int8, TensorDataType.uint8]
2528

2629

2730
@api()
@@ -421,3 +424,41 @@ def get_params_configured_by_preset(self, quant_group: QuantizerGroup) -> dict[s
421424
if quant_group == QuantizerGroup.ACTIVATIONS and self == QuantizationPreset.MIXED:
422425
return {"mode": QuantizationScheme.ASYMMETRIC}
423426
return {"mode": QuantizationScheme.SYMMETRIC}
427+
428+
429+
class TypedQuantizerConfig(QuantizerConfig):
    """
    Extended configuration class for quantizers, including the destination integer dtype
    the quantized values should be stored in (e.g. for ExecuTorch-specific quantizer configs).
    """

    def __init__(
        self,
        num_bits: int = QUANTIZATION_BITS,
        mode: QuantizationScheme = QuantizationScheme.SYMMETRIC,
        signedness_to_force: Optional[bool] = None,
        per_channel: bool = QUANTIZATION_PER_CHANNEL,
        narrow_range: bool = QUANTIZATION_NARROW_RANGE,
        dest_dtype: IntDtype = TensorDataType.int8,
    ):
        """
        :param num_bits: Bitwidth of the quantization.
        :param mode: The mode of quantization (symmetric or asymmetric).
        :param signedness_to_force: True if the quantizer *must* be signed, False if *must* be unsigned,
            None if the signed/unsigned attribute should be determined based on the incoming activation
            statistics during range initialization.
        :param per_channel: True for per-channel quantization, False for per-tensor.
        :param narrow_range: True if the range of quantized values should be narrowed as compared to the
            naive case, False if all 2^`num_bits` quantizations should be used.
        :param dest_dtype: Target integer data type for quantized values.
        """
        super().__init__(num_bits, mode, signedness_to_force, per_channel, narrow_range)
        self.dest_dtype = dest_dtype

    def __str__(self) -> str:
        # Fix: the original returned the literal text "{self._dest_dtype}" — the string was
        # missing the f-prefix, and `_dest_dtype` does not exist (the constructor stores
        # the value as `dest_dtype`).
        return super().__str__() + f" DestDtype: {self.dest_dtype}"

    def get_state(self) -> dict[str, Any]:
        """
        Returns the serializable state of the config, extending the base
        :class:`QuantizerConfig` state with the `dest_dtype` field.

        :return: State dict suitable for `from_state`-style reconstruction.
        """
        state = super().get_state()
        state["dest_dtype"] = self.dest_dtype
        return state

src/nncf/experimental/quantization/structs.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
3030
from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint
3131
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
32+
from nncf.common.quantization.structs import TypedQuantizerConfig
3233
from nncf.experimental.quantization.quantizer import Quantizer
33-
from nncf.experimental.quantization.structs import ExtendedQuantizerConfig
3434
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
3535
from nncf.tensor.definitions import TensorDataType
3636

@@ -72,7 +72,7 @@ def _get_quantization_points(
7272
from_node: torch.fx.Node,
7373
to_nodes: list[torch.fx.Node],
7474
annotated_model: torch.fx.GraphModule,
75-
qconfig: ExtendedQuantizerConfig,
75+
qconfig: TypedQuantizerConfig,
7676
) -> list[QuantizationPointBase]:
7777
"""
7878
Creates quantization points based on the nodes and edges.
@@ -167,7 +167,7 @@ def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -
167167
else QuantizationMode.ASYMMETRIC
168168
)
169169
narrow_range = qspec.quant_max - qspec.quant_min == 254
170-
qconfig = ExtendedQuantizerConfig(
170+
qconfig = TypedQuantizerConfig(
171171
mode=mode,
172172
signedness_to_force=False,
173173
per_channel=per_channel,

src/nncf/quantization/algorithms/min_max/torch_fx_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
2626
from nncf.common.quantization.structs import QuantizationScheme
2727
from nncf.common.quantization.structs import QuantizerConfig
28+
from nncf.common.quantization.structs import TypedQuantizerConfig
2829
from nncf.experimental.common.tensor_statistics.collectors import REDUCERS_MAP
2930
from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
30-
from nncf.experimental.quantization.structs import ExtendedQuantizerConfig
3131
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
3232
from nncf.experimental.torch.fx.model_utils import get_target_point
3333
from nncf.experimental.torch.fx.transformations import qdq_insertion_transformation_builder
@@ -195,7 +195,7 @@ def _create_quantizer(
195195
) -> FakeQuantize:
196196
per_channel = quantizer_config.per_channel
197197
dtype = None
198-
if isinstance(quantizer_config, ExtendedQuantizerConfig):
198+
if isinstance(quantizer_config, TypedQuantizerConfig):
199199
dtype = quantizer_config.dest_dtype
200200

201201
if dtype not in [TensorDataType.int8, TensorDataType.uint8]:

tests/torch/quantization/test_serialize_to_json.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
2020
from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint
2121
from nncf.common.quantization.structs import QuantizerConfig
22-
from nncf.experimental.quantization.structs import ExtendedQuantizerConfig
22+
from nncf.common.quantization.structs import TypedQuantizerConfig
2323
from nncf.torch.dynamic_graph.context import Scope
2424
from nncf.torch.graph.transformations.commands import PTTargetPoint
2525
from tests.cross_fw.shared.serialization import check_serialization
@@ -95,7 +95,7 @@ def test_quantizer_setup_serialization():
9595
scqp_2 = SingleConfigQuantizationPoint(aqip, qc, directly_quantized_operator_node_names=[str(scope)])
9696
check_serialization(scqp_2)
9797

98-
ex_qc = ExtendedQuantizerConfig()
98+
ex_qc = TypedQuantizerConfig()
9999
scqp_ex = SingleConfigQuantizationPoint(aqip, ex_qc, directly_quantized_operator_node_names=[str(scope)])
100100
check_serialization(scqp_ex)
101101

tests/torch2/fx/test_calculation_quantizer_params.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
import torch
1818

1919
import nncf
20+
from nncf.common.quantization.structs import IntDtype
2021
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
2122
from nncf.common.quantization.structs import QuantizerConfig
2223
from nncf.common.quantization.structs import QuantizerGroup
24+
from nncf.common.quantization.structs import TypedQuantizerConfig
2325
from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
24-
from nncf.experimental.quantization.structs import ExtendedQuantizerConfig
25-
from nncf.experimental.quantization.structs import IntDtype
2626
from nncf.quantization.algorithms.min_max.torch_fx_backend import FXMinMaxAlgoBackend
2727
from nncf.quantization.fake_quantize import FakeQuantizeParameters
2828
from nncf.quantization.fake_quantize import calculate_quantizer_parameters
@@ -87,7 +87,7 @@ def test_quantizer_params_sym(case_to_test: CaseQuantParams, dtype: Optional[Int
8787
narrow_range = case_to_test.narrow_range
8888
mode = QuantizationMode.SYMMETRIC
8989
signedness_to_force = None
90-
qconfig = ExtendedQuantizerConfig(
90+
qconfig = TypedQuantizerConfig(
9191
num_bits=8,
9292
mode=mode,
9393
per_channel=per_ch,
@@ -387,7 +387,7 @@ def test_quantizer_params_asym(case_to_test: CaseQuantParams, ref_zp: Union[int,
387387
per_ch = case_to_test.per_channel
388388
narrow_range = case_to_test.narrow_range
389389
mode = QuantizationMode.ASYMMETRIC
390-
qconfig = ExtendedQuantizerConfig(
390+
qconfig = TypedQuantizerConfig(
391391
num_bits=8,
392392
mode=mode,
393393
per_channel=per_ch,
@@ -452,7 +452,7 @@ def _check_q_min_q_max(quantizer, signed, narrow_range):
452452
],
453453
)
454454
def test_extended_q_config_non_supported_dest_dtype(dest_dtype):
455-
qconfig = ExtendedQuantizerConfig(dest_dtype=dest_dtype)
455+
qconfig = TypedQuantizerConfig(dest_dtype=dest_dtype)
456456
params = FakeQuantizeParameters(-1.0, 1.0, -1.0, 1.0, 255)
457457
with pytest.raises(nncf.ParameterNotSupportedError):
458458
FXMinMaxAlgoBackend._create_quantizer(quantizer_config=qconfig, channel_axis=1, parameters=params)

0 commit comments

Comments
 (0)