Skip to content

Commit 69796b5

Browse files
Comments
1 parent e6c4c31 commit 69796b5

File tree

4 files changed

+26
-11
lines changed

4 files changed

+26
-11
lines changed

src/nncf/experimental/quantization/structs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from typing import Any, Literal, Optional
1313

14+
import nncf
1415
from nncf.common.quantization.structs import QuantizationScheme
1516
from nncf.common.quantization.structs import QuantizerConfig
1617
from nncf.config.schemata.defaults import QUANTIZATION_BITS
@@ -49,7 +50,7 @@ def __init__(
4950
super().__init__(num_bits, mode, signedness_to_force, per_channel, narrow_range)
5051
if dest_dtype not in [TensorDataType.int8, TensorDataType.uint8]:
5152
msg = f"Quantization configurations with dest_dtype=={dest_dtype} are not supported."
52-
raise RuntimeError(msg)
53+
raise nncf.ParameterNotSupportedError(msg)
5354
self.dest_dtype = dest_dtype
5455

5556
def __str__(self) -> str:

tests/common/experimental/test_structs.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,28 @@
1111

1212
import pytest
1313

14+
import nncf
1415
from nncf.experimental.quantization.structs import ExtendedQuantizerConfig
1516
from nncf.tensor.definitions import TensorDataType
1617

1718

18-
def test_extended_q_config_non_supported_dest_dtype():
19-
with pytest.raises(RuntimeError):
20-
ExtendedQuantizerConfig(dest_dtype=TensorDataType.int4)
21-
with pytest.raises(RuntimeError):
22-
ExtendedQuantizerConfig(dest_dtype=None)
19+
@pytest.mark.parametrize(
20+
"dest_dtype",
21+
[
22+
TensorDataType.float16,
23+
TensorDataType.bfloat16,
24+
TensorDataType.float32,
25+
TensorDataType.float64,
26+
TensorDataType.f8e4m3,
27+
TensorDataType.f8e5m2,
28+
TensorDataType.nf4,
29+
TensorDataType.int32,
30+
TensorDataType.int64,
31+
TensorDataType.uint4,
32+
TensorDataType.int4,
33+
None,
34+
],
35+
)
36+
def test_extended_q_config_non_supported_dest_dtype(dest_dtype):
37+
with pytest.raises(nncf.ParameterNotSupportedError):
38+
ExtendedQuantizerConfig(dest_dtype=dest_dtype)

tests/post_training/pipelines/image_classification_base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from nncf.common.logging.track_progress import track
3333
from nncf.experimental.torch.fx import OpenVINOQuantizer
3434
from nncf.experimental.torch.fx import quantize_pt2e
35-
from nncf.torch import disable_patching
3635
from tests.post_training.pipelines.base import DEFAULT_VAL_THREADS
3736
from tests.post_training.pipelines.base import FX_BACKENDS
3837
from tests.post_training.pipelines.base import BackendType
@@ -130,7 +129,7 @@ def _validate(self) -> None:
130129
return []
131130

132131
def _compress_torch_ao(self, quantizer):
133-
with torch.no_grad(), disable_patching():
132+
with torch.no_grad():
134133
prepared_model = prepare_pt2e(self.model, quantizer)
135134
subset_size = self.compression_params.get("subset_size", 300)
136135
for data in islice(self.calibration_dataset.get_inference_data(), subset_size):
@@ -167,7 +166,7 @@ def _compress_nncf_pt2e(self, quantizer):
167166
if self.compression_params.get("model_type", False):
168167
smooth_quant = self.compression_params["model_type"] == nncf.ModelType.TRANSFORMER
169168

170-
with disable_patching(), torch.no_grad():
169+
with torch.no_grad():
171170
self.compressed_model = quantize_pt2e(
172171
self.model,
173172
quantizer,
@@ -186,7 +185,7 @@ def _compress(self):
186185

187186
return
188187
if self.backend in [BackendType.FX_TORCH, BackendType.CUDA_FX_TORCH]:
189-
with disable_patching(), torch.no_grad():
188+
with torch.no_grad():
190189
super()._compress()
191190
return
192191

tests/torch2/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ def pytest_configure(config: Config) -> None:
5050
regen_dot = config.getoption("--regen-ref-data", False)
5151
if regen_dot:
5252
os.environ["NNCF_TEST_REGEN_DOT"] = "1"
53-
os.environ["NNCF_TEST_REGEN_JSON"] = "1"
5453

5554
nncf_debug = config.getoption("--nncf-debug", False)
5655
if nncf_debug:

0 commit comments

Comments
 (0)