Skip to content

Commit 4ef35e0

Browse files
authored
[OpenVINO] Fix Quantizer for PTQ (#15891)
1 parent 93bf861 commit 4ef35e0

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

backends/openvino/quantizer/quantizer.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,19 @@
2020
INT8WeightObserver,
2121
)
2222
from nncf.common.graph.graph import NNCFGraph # type: ignore[import-untyped]
23+
from nncf.common.logging import nncf_logger # type: ignore[import-untyped]
24+
from nncf.quantization.algorithms.min_max.algorithm import ( # type: ignore[import-untyped]
25+
MinMaxQuantization,
26+
)
2327
from nncf.quantization.algorithms.weight_compression.config import ( # type: ignore[import-untyped]
2428
WeightCompressionParameters,
2529
)
2630
from nncf.quantization.quantize_model import ( # type: ignore[import-untyped]
2731
get_weight_compression_configuration,
2832
)
33+
from nncf.torch.model_graph_manager import ( # type: ignore[import-untyped]
34+
get_weight_tensor_port_ids,
35+
)
2936
from torchao.quantization.pt2e import (
3037
HistogramObserver,
3138
PerChannelMinMaxObserver,
@@ -105,16 +112,15 @@ def __init__(
105112
else:
106113
preset = None
107114
model_type = nncf.parameters.ModelType.TRANSFORMER
108-
self._algo = (
109-
nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization(
110-
preset=preset, model_type=model_type, **kwargs
111-
)
115+
self._algo = MinMaxQuantization(
116+
preset=preset, model_type=model_type, **kwargs
112117
)
113118
else:
119+
compression_mode = mode.value.replace(
120+
"wo", ""
121+
) # Mode value has to match NNCF CompressWeightsMode
114122
weight_compression_configuration = get_weight_compression_configuration(
115-
mode.value.replace(
116-
"wo", ""
117-
), # Mode value has to match NNCF CompressWeightsMode
123+
nncf.CompressWeightsMode(compression_mode),
118124
**kwargs,
119125
)
120126
subset_size = 1 # Doesn't really matter in this case since it is data-free. Should just be +ve
@@ -354,12 +360,10 @@ def _get_weight_edge(
354360
:return: Edge represented by a Tuple of (weight_node, target_node), where weight_node is the FX node supplying the weight.
355361
"""
356362
nncf_node = nncf_graph.get_node_by_name(target_node.name)
357-
weights_ports_ids = nncf.torch.model_graph_manager.get_weight_tensor_port_ids(
358-
nncf_node, nncf_graph
359-
)
363+
weights_ports_ids = get_weight_tensor_port_ids(nncf_node, nncf_graph)
360364
if len(weights_ports_ids) > 1:
361365
# TODO(dlyakhov): support quantization for nodes with several weights
362-
nncf.common.logging.nncf_logger.warning(
366+
nncf_logger.warning(
363367
f"Quantization of the weighted node {target_node.name}"
364368
" is not yet supported by the OpenVINOQuantizer."
365369
f" Only the weight on port ID {weights_ports_ids[0]} will be quantized."
@@ -384,7 +388,7 @@ def _get_edge_or_node(
384388
"""
385389
ip = qp.insertion_point
386390
if qp.is_weight_quantization_point():
387-
OpenVINOQuantizer._get_weight_edge(target_node, nncf_graph)
391+
return OpenVINOQuantizer._get_weight_edge(target_node, nncf_graph)
388392

389393
if ip.input_port_id is None:
390394
return target_node

0 commit comments

Comments
 (0)