10
10
# limitations under the License.
11
11
12
12
from collections import defaultdict
13
- from typing import Dict , List , Optional , Tuple , Union
13
+ from typing import Dict , List , Optional , Tuple
14
14
15
15
import torch .fx
16
16
from torch .ao .quantization .observer import HistogramObserver
17
17
from torch .ao .quantization .observer import PerChannelMinMaxObserver
18
- from torch .ao .quantization .observer import MinMaxObserver
19
18
from torch .ao .quantization .quantizer .quantizer import EdgeOrNode
20
19
from torch .ao .quantization .quantizer .quantizer import QuantizationAnnotation
21
20
from torch .ao .quantization .quantizer .quantizer import QuantizationSpec
24
23
from torch .ao .quantization .quantizer .quantizer import SharedQuantizationSpec
25
24
26
25
import nncf
26
+ import nncf .common .quantization as q
27
+ import nncf .experimental .torch .fx as nncf_fx
28
+ import nncf .parameters as p
29
+ import nncf .quantization .advanced_parameters as advanced_p
27
30
from nncf .common .graph .graph import NNCFGraph
28
- from nncf .common .logging import nncf_logger
29
- from nncf .common .quantization .quantizer_propagation .solver import QuantizerPropagationRule
30
- from nncf .common .quantization .quantizer_setup import QuantizationPointBase
31
- from nncf .common .quantization .quantizer_setup import SingleConfigQuantizerSetup
32
- from nncf .common .quantization .structs import QuantizationPreset
33
- from nncf .common .quantization .structs import QuantizationScheme
34
- from nncf .experimental .torch .fx .nncf_graph_builder import GraphConverter
35
- from nncf .experimental .torch .fx .node_utils import get_graph_node_by_name
36
- from nncf .experimental .torch .fx .transformations import fold_constant_except_qdq
37
- from nncf .parameters import ModelType
38
- from nncf .parameters import QuantizationMode
39
- from nncf .parameters import TargetDevice
40
- from nncf .quantization .advanced_parameters import FP8QuantizationParameters
41
- from nncf .quantization .advanced_parameters import OverflowFix
42
- from nncf .quantization .advanced_parameters import QuantizationParameters
43
- from nncf .quantization .algorithms .min_max .algorithm import MinMaxQuantization
44
- from nncf .scopes import IgnoredScope
45
- from nncf .torch .model_graph_manager import get_weight_tensor_port_ids
46
31
47
32
QUANT_ANNOTATION_KEY = "quantization_annotation"
48
33
@@ -56,16 +41,15 @@ class OpenVINOQuantizer(Quantizer):
56
41
def __init__(
    self,
    *,
    mode: Optional[p.QuantizationMode] = None,
    preset: Optional[q.structs.QuantizationPreset] = None,
    target_device: p.TargetDevice = p.TargetDevice.ANY,
    transformer_model: bool = False,
    ignored_scope: Optional[nncf.IgnoredScope] = None,
    overflow_fix: Optional[advanced_p.OverflowFix] = None,
    quantize_outputs: bool = False,
    activations_quantization_params: Optional[advanced_p.QuantizationParameters] = None,
    weights_quantization_params: Optional[advanced_p.QuantizationParameters] = None,
):
    """
    Creates the NNCF MinMaxQuantization algorithm instance that this quantizer
    delegates to and stores it in `self._min_max_algo`.

    All arguments are keyword-only. Except for `transformer_model`, every argument
    is forwarded unchanged to the MinMaxQuantization constructor under the same name.

    :param mode: Defines optimization mode for the algorithm. None by default.
    :param preset: Quantization preset passed to MinMaxQuantization. None by default.
    :param target_device: Target device passed to MinMaxQuantization.
        TargetDevice.ANY by default.
    :param transformer_model: When True, MinMaxQuantization receives
        `model_type=ModelType.TRANSFORMER`; otherwise it receives `model_type=None`.
        False by default.
    :param ignored_scope: Ignored scope passed to MinMaxQuantization. None by default.
    :param overflow_fix: Overflow fix setting passed to MinMaxQuantization.
        None by default.
    :param quantize_outputs: Passed to MinMaxQuantization as is. False by default.
    :param activations_quantization_params: Quantization parameters for model
        activations.
    :param weights_quantization_params: Quantization parameters for model weights.
    """
    # Resolve the algorithm class first so the lookup order matches a direct call expression.
    min_max_cls = nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization
    # The boolean flag is the only argument that is translated before being forwarded.
    model_type = p.ModelType.TRANSFORMER if transformer_model else None
    self._min_max_algo = min_max_cls(
        mode=mode,
        preset=preset,
        target_device=target_device,
        model_type=model_type,
        ignored_scope=ignored_scope,
        overflow_fix=overflow_fix,
        quantize_outputs=quantize_outputs,
        activations_quantization_params=activations_quantization_params,
        weights_quantization_params=weights_quantization_params,
    )
107
88
108
def get_nncf_quantization_setup(
    self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph
) -> q.quantizer_setup.SingleConfigQuantizerSetup:
    """
    Asks the wrapped MinMaxQuantization algorithm for the quantizer setup of the model.

    :param model: The torch.fx.GraphModule to compute the setup for.
    :param nncf_graph: NNCFGraph passed to the algorithm together with the model.
    :return: The value returned by `find_quantization_setup` of the wrapped algorithm.
    """
    algorithm = self._min_max_algo
    # The backend entity is set from the model before the setup is searched for.
    algorithm._set_backend_entity(model)
    quantizer_setup = algorithm.find_quantization_setup(model, nncf_graph)
    return quantizer_setup
111
94
112
95
def annotate (self , model : torch .fx .GraphModule ) -> torch .fx .GraphModule :
113
- nncf_graph = GraphConverter .create_nncf_graph (model )
114
- quantization_setup = self .get_quantization_setup (model , nncf_graph )
96
+ nncf_graph = nncf_fx . nncf_graph_builder . GraphConverter .create_nncf_graph (model )
97
+ quantization_setup = self .get_nncf_quantization_setup (model , nncf_graph )
115
98
116
99
graph = model .graph
117
100
node_vs_torch_annotation = defaultdict (QuantizationAnnotation )
@@ -138,7 +121,9 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
138
121
)
139
122
raise nncf .InternalError (msg )
140
123
141
- root_target_node = get_graph_node_by_name (graph , root_qp .insertion_point .target_node_name )
124
+ root_target_node = nncf_fx .node_utils .get_graph_node_by_name (
125
+ graph , root_qp .insertion_point .target_node_name
126
+ )
142
127
root_edge_or_node = self ._get_edge_or_node (root_target_node , root_qp , nncf_graph )
143
128
144
129
for quantizer_id in quantizer_ids :
@@ -155,10 +140,11 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
155
140
for node , annotation in node_vs_torch_annotation .items ():
156
141
assert QUANT_ANNOTATION_KEY not in node .meta
157
142
node .meta [QUANT_ANNOTATION_KEY ] = annotation
143
+ return model
158
144
159
145
@staticmethod
160
146
def _get_unified_scales_root_quantizer_id (
161
- nncf_graph : NNCFGraph , quantizer_ids : List [int ], quantizer_setup : SingleConfigQuantizerSetup
147
+ nncf_graph : NNCFGraph , quantizer_ids : List [int ], quantizer_setup : q . quantizer_setup . SingleConfigQuantizerSetup
162
148
) -> int :
163
149
"""
164
150
Identifies the earliest quantizer node ID based on the corresponding `nncf_node.node_id`
@@ -184,7 +170,7 @@ def _get_unified_scales_root_quantizer_id(
184
170
def _get_edge_or_node_and_annotation(
    graph: torch.fx.Graph,
    nncf_graph: NNCFGraph,
    qp: q.quantizer_setup.QuantizationPointBase,
    node_vs_torch_annotation: Dict[torch.fx.Node, QuantizationAnnotation],
) -> Tuple[EdgeOrNode, QuantizationAnnotation]:
    """
    Looks up the FX node targeted by a quantization point and returns the matching
    EdgeOrNode together with the annotation stored for that node.

    :param graph: torch.fx.Graph in which the target node is searched by name.
    :param nncf_graph: NNCFGraph forwarded to `_get_edge_or_node`.
    :param qp: Quantization point; its `insertion_point.target_node_name` selects the node.
    :param node_vs_torch_annotation: Mapping from FX nodes to their
        QuantizationAnnotations. It is indexed directly, so a defaultdict
        creates a new annotation for a node that has no entry yet.
    :return: A tuple containing the EdgeOrNode and its associated QuantizationAnnotation.
    """
    fx_node = nncf_fx.node_utils.get_graph_node_by_name(graph, qp.insertion_point.target_node_name)
    # The annotation is fetched before the edge/node is resolved, keeping the original order.
    torch_annotation = node_vs_torch_annotation[fx_node]
    return OpenVINOQuantizer._get_edge_or_node(fx_node, qp, nncf_graph), torch_annotation
205
191
206
192
@staticmethod
207
- def _get_edge_or_node (target_node : torch .fx .Node , qp : QuantizationPointBase , nncf_graph : NNCFGraph ) -> EdgeOrNode :
193
+ def _get_edge_or_node (
194
+ target_node : torch .fx .Node , qp : q .quantizer_setup .QuantizationPointBase , nncf_graph : NNCFGraph
195
+ ) -> EdgeOrNode :
208
196
"""
209
197
Returns the edge or node based on the given target node and quantization point.
210
198
@@ -216,10 +204,10 @@ def _get_edge_or_node(target_node: torch.fx.Node, qp: QuantizationPointBase, nnc
216
204
ip = qp .insertion_point
217
205
if qp .is_weight_quantization_point ():
218
206
nncf_node = nncf_graph .get_node_by_name (target_node .name )
219
- weights_ports_ids = get_weight_tensor_port_ids (nncf_node , nncf_graph )
207
+ weights_ports_ids = nncf . torch . model_graph_manager . get_weight_tensor_port_ids (nncf_node , nncf_graph )
220
208
if len (weights_ports_ids ) > 1 :
221
209
# TODO(dlyakhov): support quantization for nodes with several weights
222
- nncf_logger .warning (
210
+ nncf . common . logging . nncf_logger .warning (
223
211
f"Quantization of the weighted node { target_node .name } "
224
212
" is not yet supported by the OpenVINOQuantizer."
225
213
f" Only the weight on port ID { weights_ports_ids [0 ]} will be quantized."
@@ -253,7 +241,7 @@ def _fill_torch_ao_annotation(
253
241
annotation_to_update .input_qspec_map [edge_or_node [0 ]] = qspec
254
242
255
243
@staticmethod
256
- def _get_torch_ao_qspec_from_qp (qp : QuantizationPointBase ) -> QuantizationSpec :
244
+ def _get_torch_ao_qspec_from_qp (qp : q . quantizer_setup . QuantizationPointBase ) -> QuantizationSpec :
257
245
"""
258
246
Retrieves the quantization configuration from the given quantization point and
259
247
converts it into a QuantizationSpec.
@@ -269,15 +257,16 @@ def _get_torch_ao_qspec_from_qp(qp: QuantizationPointBase) -> QuantizationSpec:
269
257
if qconfig .per_channel :
270
258
torch_qscheme = (
271
259
torch .per_channel_symmetric
272
- if qconfig .mode is QuantizationScheme .SYMMETRIC
260
+ if qconfig .mode is q . structs . QuantizationScheme .SYMMETRIC
273
261
else torch .per_channel_affine
274
262
)
275
263
else :
276
264
torch_qscheme = (
277
- torch .per_tensor_symmetric if qconfig .mode is QuantizationScheme .SYMMETRIC else torch .per_tensor_affine
265
+ torch .per_tensor_symmetric
266
+ if qconfig .mode is q .structs .QuantizationScheme .SYMMETRIC
267
+ else torch .per_tensor_affine
278
268
)
279
269
if is_weight :
280
- observer = PerChannelMinMaxObserver if qconfig .per_channel else MinMaxObserver
281
270
observer = PerChannelMinMaxObserver
282
271
quant_min = - 128
283
272
quant_max = 127
@@ -307,5 +296,5 @@ def validate(self, model: torch.fx.GraphModule) -> None:
307
296
pass
308
297
309
298
def transform_for_annotation(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Applies `fold_constant_except_qdq` to the model and returns that same model object.

    :param model: The torch.fx.GraphModule handed to the transformation.
    :return: The `model` argument itself; the helper's return value is not used.
    """
    # NOTE(review): judging by its name the helper folds constants but leaves
    # quantize/dequantize nodes alone - confirm against nncf.experimental.torch.fx.
    fold_constants = nncf_fx.transformations.fold_constant_except_qdq
    fold_constants(model)
    return model
0 commit comments