Skip to content

Commit adaadaa

Browse files
Comments
1 parent 7a2c141 commit adaadaa

File tree

8 files changed

+70
-25
lines changed

8 files changed

+70
-25
lines changed

src/nncf/experimental/quantization/structs.py

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -9,23 +9,17 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Any, Optional
12+
from typing import Any, Literal, Optional
1313

14+
import nncf
1415
from nncf.common.quantization.structs import QuantizationScheme
1516
from nncf.common.quantization.structs import QuantizerConfig
1617
from nncf.config.schemata.defaults import QUANTIZATION_BITS
1718
from nncf.config.schemata.defaults import QUANTIZATION_NARROW_RANGE
1819
from nncf.config.schemata.defaults import QUANTIZATION_PER_CHANNEL
19-
from nncf.parameters import StrEnum
20+
from nncf.tensor.definitions import TensorDataType
2021

21-
22-
class IntDtype(StrEnum):
23-
"""
24-
Enum of possible integer types.
25-
"""
26-
27-
INT8 = "INT8"
28-
UINT8 = "UINT8"
22+
IntDtype = Literal[TensorDataType.int8, TensorDataType.uint8]
2923

3024

3125
class ExtendedQuantizerConfig(QuantizerConfig):
@@ -40,7 +34,7 @@ def __init__(
4034
signedness_to_force: Optional[bool] = None,
4135
per_channel: bool = QUANTIZATION_PER_CHANNEL,
4236
narrow_range: bool = QUANTIZATION_NARROW_RANGE,
43-
dest_dtype: IntDtype = IntDtype.INT8,
37+
dest_dtype: IntDtype = TensorDataType.int8,
4438
):
4539
"""
4640
:param num_bits: Bitwidth of the quantization.
@@ -54,6 +48,9 @@ def __init__(
5448
:param dest_dtype: Target integer data type for quantized values.
5549
"""
5650
super().__init__(num_bits, mode, signedness_to_force, per_channel, narrow_range)
51+
if dest_dtype not in [TensorDataType.int8, TensorDataType.uint8]:
52+
msg = f"Quantization configurations with dest_dtype=={dest_dtype} are not supported."
53+
raise nncf.ParameterNotSupportedError(msg)
5754
self.dest_dtype = dest_dtype
5855

5956
def __str__(self) -> str:

src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
3232
from nncf.experimental.quantization.quantizer import Quantizer
3333
from nncf.experimental.quantization.structs import ExtendedQuantizerConfig
34-
from nncf.experimental.quantization.structs import IntDtype
3534
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
35+
from nncf.tensor.definitions import TensorDataType
3636

3737
EdgeOrNode = Union[tuple[torch.fx.Node, torch.fx.Node]]
3838

@@ -160,7 +160,7 @@ def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -
160160
msg = f"Unknown qscheme: {qspec.qscheme}"
161161
raise nncf.InternalError(msg)
162162

163-
dtype = IntDtype.INT8 if qspec.dtype is torch.int8 else IntDtype.UINT8
163+
dtype = TensorDataType.int8 if qspec.dtype is torch.int8 else TensorDataType.uint8
164164
mode = (
165165
QuantizationMode.SYMMETRIC
166166
if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]

src/nncf/quantization/algorithms/min_max/torch_fx_backend.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from nncf.experimental.common.tensor_statistics.collectors import REDUCERS_MAP
2929
from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
3030
from nncf.experimental.quantization.structs import ExtendedQuantizerConfig
31-
from nncf.experimental.quantization.structs import IntDtype
3231
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
3332
from nncf.experimental.torch.fx.model_utils import get_target_point
3433
from nncf.experimental.torch.fx.transformations import qdq_insertion_transformation_builder
@@ -38,6 +37,7 @@
3837
from nncf.quantization.fake_quantize import FakeConvertParameters
3938
from nncf.quantization.fake_quantize import FakeQuantizeParameters
4039
from nncf.quantization.range_estimator import StatisticsType
40+
from nncf.tensor.definitions import TensorDataType
4141
from nncf.torch.graph.graph import PTNNCFGraph
4242
from nncf.torch.graph.graph import PTTargetPoint
4343
from nncf.torch.graph.operator_metatypes import ELEMENTWISE_OPERATIONS
@@ -199,20 +199,20 @@ def _create_quantizer(
199199
if isinstance(quantizer_config, ExtendedQuantizerConfig):
200200
dtype = quantizer_config.dest_dtype
201201
elif quantizer_config.mode != QuantizationScheme.SYMMETRIC:
202-
dtype = IntDtype.UINT8
202+
dtype = TensorDataType.uint8
203203
else:
204204
dtype = (
205-
IntDtype.INT8
205+
TensorDataType.int8
206206
if quantizer_config.signedness_to_force or torch.any(parameters.input_low.data < 0.0)
207-
else IntDtype.UINT8
207+
else TensorDataType.uint8
208208
)
209209

210210
if per_channel:
211211
observer = torch.ao.quantization.observer.PerChannelMinMaxObserver
212212
else:
213213
observer = torch.ao.quantization.observer.MinMaxObserver
214214

215-
if dtype is IntDtype.INT8:
215+
if dtype is TensorDataType.int8:
216216
level_high = 127
217217
level_low = -128
218218
else:
@@ -241,7 +241,7 @@ def _create_quantizer(
241241
observer=observer,
242242
quant_max=level_high,
243243
quant_min=level_low,
244-
dtype=torch.qint8 if dtype is IntDtype.INT8 else torch.quint8,
244+
dtype=torch.qint8 if dtype is TensorDataType.int8 else torch.quint8,
245245
qscheme=qscheme,
246246
eps=1e-16,
247247
)

src/nncf/tensor/definitions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ class TensorDataType(StrEnum):
3838
Enum representing the different tensor data types.
3939
"""
4040

41+
@staticmethod
42+
def _generate_next_value_(name, start, count, last_values):
43+
return name.lower()
44+
4145
float16 = auto()
4246
bfloat16 = auto()
4347
float32 = auto()
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import pytest
13+
14+
import nncf
15+
from nncf.experimental.quantization.structs import ExtendedQuantizerConfig
16+
from nncf.tensor.definitions import TensorDataType
17+
18+
19+
@pytest.mark.parametrize(
20+
"dest_dtype",
21+
[
22+
TensorDataType.float16,
23+
TensorDataType.bfloat16,
24+
TensorDataType.float32,
25+
TensorDataType.float64,
26+
TensorDataType.f8e4m3,
27+
TensorDataType.f8e5m2,
28+
TensorDataType.nf4,
29+
TensorDataType.int32,
30+
TensorDataType.int64,
31+
TensorDataType.uint4,
32+
TensorDataType.int4,
33+
None,
34+
],
35+
)
36+
def test_extended_q_config_non_supported_dest_dtype(dest_dtype):
37+
with pytest.raises(nncf.ParameterNotSupportedError):
38+
ExtendedQuantizerConfig(dest_dtype=dest_dtype)

tests/torch2/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def pytest_configure(config: Config) -> None:
5050
regen_dot = config.getoption("--regen-ref-data", False)
5151
if regen_dot:
5252
os.environ["NNCF_TEST_REGEN_DOT"] = "1"
53+
os.environ["NNCF_TEST_REGEN_JSON"] = "1"
5354

5455
nncf_debug = config.getoption("--nncf-debug", False)
5556
if nncf_debug:

tests/torch2/fx/test_calculation_quantizer_params.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from nncf.quantization.algorithms.min_max.torch_fx_backend import FXMinMaxAlgoBackend
2727
from nncf.quantization.fake_quantize import calculate_quantizer_parameters
2828
from nncf.tensor import Tensor
29+
from nncf.tensor.definitions import TensorDataType
2930

3031
INPUT_SHAPE = (2, 3, 4, 5)
3132

@@ -79,7 +80,7 @@ class CaseQuantParams:
7980

8081

8182
@pytest.mark.parametrize("case_to_test", SYM_CASES)
82-
@pytest.mark.parametrize("dtype", [IntDtype.UINT8, IntDtype.INT8])
83+
@pytest.mark.parametrize("dtype", [TensorDataType.uint8, TensorDataType.int8])
8384
def test_quantizer_params_sym(case_to_test: CaseQuantParams, dtype: Optional[IntDtype]):
8485
per_ch = case_to_test.per_channel
8586
narrow_range = case_to_test.narrow_range
@@ -97,7 +98,7 @@ def test_quantizer_params_sym(case_to_test: CaseQuantParams, dtype: Optional[Int
9798
quantizer = _get_quantizer(case_to_test, qconfig)
9899
assert quantizer.qscheme is torch.per_channel_symmetric if case_to_test.per_channel else torch.per_tensor_symmetric
99100

100-
signed = signedness_to_force or dtype is IntDtype.INT8
101+
signed = signedness_to_force or dtype is TensorDataType.int8
101102
if signed:
102103
assert torch.allclose(quantizer.zero_point, torch.tensor(0, dtype=torch.int8))
103104
else:
@@ -380,7 +381,7 @@ def test_quantizer_params_sym_nr(case_to_test: CaseQuantParams, ref_signed: bool
380381

381382

382383
@pytest.mark.parametrize("case_to_test,ref_zp", ASYM_CASES)
383-
@pytest.mark.parametrize("dtype", [IntDtype.UINT8, IntDtype.INT8])
384+
@pytest.mark.parametrize("dtype", [TensorDataType.uint8, TensorDataType.int8])
384385
def test_quantizer_params_asym(case_to_test: CaseQuantParams, ref_zp: Union[int, list[int]], dtype: Optional[IntDtype]):
385386
per_ch = case_to_test.per_channel
386387
narrow_range = case_to_test.narrow_range
@@ -397,7 +398,7 @@ def test_quantizer_params_asym(case_to_test: CaseQuantParams, ref_zp: Union[int,
397398
quantizer = _get_quantizer(case_to_test, qconfig)
398399
assert quantizer.qscheme is torch.per_channel_affine if case_to_test.per_channel else torch.per_tensor_affine
399400

400-
signed = dtype is IntDtype.INT8
401+
signed = dtype is TensorDataType.int8
401402
ref_zp = torch.tensor(ref_zp)
402403
if not signed:
403404
ref_zp += 127 if narrow_range else 128

tests/torch2/fx/test_quantizer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@
3636
import nncf
3737
from nncf.common.graph import NNCFGraph
3838
from nncf.common.utils.os import safe_open
39-
from nncf.experimental.quantization.structs import IntDtype
4039
from nncf.experimental.torch.fx import quantize_pt2e
4140
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
4241
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
4342
from nncf.experimental.torch.fx.quantization.quantizer.openvino_adapter import OpenVINOQuantizerAdapter
4443
from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer
4544
from nncf.experimental.torch.fx.quantization.quantizer.torch_ao_adapter import TorchAOQuantizerAdapter
4645
from nncf.experimental.torch.fx.quantization.quantizer.torch_ao_adapter import _get_edge_or_node_to_qspec
46+
from nncf.tensor.definitions import TensorDataType
4747
from tests.cross_fw.shared.nx_graph import compare_nx_graph_with_reference
4848
from tests.cross_fw.shared.paths import TEST_ROOT
4949
from tests.torch import test_models
@@ -256,6 +256,8 @@ def _normalize_qsetup_state(setup: dict[str, Any]) -> None:
256256
for qp in setup["quantization_points"].values():
257257
sorted_dq = sorted(qp[dq_key])
258258
qconfig = qp["qconfig"].copy()
259+
if "dest_dtype" in qconfig:
260+
qconfig["dest_dtype"] = "INT8" if qconfig["dest_dtype"] is TensorDataType.int8 else "UINT8"
259261
sorted_qps[f"{tuple(sorted_dq)}_{qp['qip_class']}"] = qconfig
260262
setup["quantization_points"] = sorted_qps
261263

@@ -285,7 +287,9 @@ def _normalize_nncf_graph(nncf_graph: NNCFGraph, fx_graph: torch.fx.Graph):
285287
idx += 1
286288
if node.node_type in ["dequantize_per_tensor", "dequantize_per_channel"]:
287289
source_node = get_graph_node_by_name(fx_graph, node.node_name)
288-
dtypes_map[new_node_name] = IntDtype.INT8 if source_node.args[-1] == torch.int8 else IntDtype.UINT8
290+
dtypes_map[new_node_name] = (
291+
TensorDataType.int8 if source_node.args[-1] == torch.int8 else TensorDataType.uint8
292+
)
289293
norm_nncf_graph.add_nncf_node(
290294
node_name=attrs[node.NODE_NAME_ATTR],
291295
node_type=attrs[node.NODE_TYPE_ATTR],

0 commit comments

Comments
 (0)