Skip to content

Commit 06d4589

Browse files
authored
Remove quantization functions from node quantization configs (#1477)
* Remove `activation_quantization_fn` and `activation_quantization_params_fn` from `NodeActivationQuantizationCfg`.
* Remove `weights_quantization_fn` and `weights_quantization_params_fn` from `WeightsAttrQuantizationConfig`.
1 parent fad922f commit 06d4589

32 files changed

Lines changed: 377 additions & 588 deletions

File tree

docsrc/source/api/api_docs/modules/network_editor.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@ Actions
4646

4747
|
4848
49-
.. autoclass:: model_compression_toolkit.core.network_editor.ChangeQuantizationParamFunction
50-
51-
|
52-
5349
.. autoclass:: model_compression_toolkit.core.network_editor.ChangeFinalWeightsQuantizationMethod
5450

5551
|

model_compression_toolkit/core/common/framework_info.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,19 @@ class FrameworkInfo(ABC):
5252
no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)
5353
5454
Fields:
55-
activation_quantizer_mapping (Dict[QuantizationMethod, Callable]): A dictionary mapping from QuantizationMethod to a quantization function.
5655
kernel_channels_mapping (Dict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
5756
kernel_ops_attribute_mapping (Dict): Dictionary from a framework operator to its weight attribute to quantize.
5857
out_channel_axis_mapping (Dict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
5958
_layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
59+
activation_quantizer_factory_mapping: A mapping from QuantizationMethod to a factory function that accepts activation bitwidth and a dict of quantization params, and returns the corresponding quantization function.
6060
"""
6161

62-
activation_quantizer_mapping: Dict[QuantizationMethod, Callable]
63-
kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
6462
kernel_ops_attribute_mapping: Dict[Any, str]
63+
kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
6564
out_channel_axis_mapping: Dict[Any, int]
66-
_layer_min_max_mapping: Dict[Any, tuple]
65+
activation_quantizer_factory_mapping: Dict[QuantizationMethod, Callable[[int, dict], Callable]]
6766

67+
_layer_min_max_mapping: Dict[Any, tuple]
6868
_default_channel_mapping = ChannelAxisMapping(None, None)
6969

7070
@classmethod

model_compression_toolkit/core/common/graph/base_graph.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -873,11 +873,7 @@ def override_fused_node_activation_quantization_candidates(self):
873873
fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
874874
if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization:
875875
def update(qc):
876-
qc.activation_quantization_cfg = NodeActivationQuantizationConfig(
877-
fusing_op_quantization_cfg,
878-
qc.activation_quantization_cfg.activation_quantization_fn,
879-
qc.activation_quantization_cfg.activation_quantization_params_fn
880-
)
876+
qc.activation_quantization_cfg = NodeActivationQuantizationConfig(fusing_op_quantization_cfg)
881877
qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
882878
node.quantization_cfg.update_all(update)
883879
node.quantization_cfg.remove_duplicates()

model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
2020
CandidateNodeQuantizationConfig
21+
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import (get_activation_quantization_fn,
22+
get_weights_quantization_fn)
2123

2224

2325
def verify_candidates_descending_order(node_q_cfg: List[CandidateNodeQuantizationConfig],
@@ -77,13 +79,13 @@ def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig],
7779
quantized_weights = []
7880
for qc in node_q_cfg:
7981
qc_weights_attr = qc.weights_quantization_cfg.get_attr_config(kernel_attr)
80-
q_weight = qc_weights_attr.weights_quantization_fn(float_weights,
81-
qc_weights_attr.weights_n_bits,
82-
True,
83-
qc_weights_attr.weights_quantization_params,
84-
qc_weights_attr.weights_per_channel_threshold,
85-
qc_weights_attr.weights_channels_axis[
86-
0]) # output channel axis
82+
weights_quantization_fn = get_weights_quantization_fn(qc_weights_attr.weights_quantization_method)
83+
q_weight = weights_quantization_fn(float_weights,
84+
qc_weights_attr.weights_n_bits,
85+
True,
86+
qc_weights_attr.weights_quantization_params,
87+
qc_weights_attr.weights_per_channel_threshold,
88+
qc_weights_attr.weights_channels_axis[0]) # output channel axis
8789

8890
quantized_weights.append(fw_tensor_convert_func(q_weight))
8991

@@ -105,6 +107,7 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
105107
activation_quantizers = []
106108
for index, qc in enumerate(node_q_cfg):
107109
q_activation = node_q_cfg[index].activation_quantization_cfg
108-
activation_quantizers.append(q_activation.quantize_node_output)
110+
quantizer = get_activation_quantization_fn(q_activation)
111+
activation_quantizers.append(quantizer)
109112

110113
return activation_quantizers

model_compression_toolkit/core/common/network_editors/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
from model_compression_toolkit.core.common.network_editors.actions import ChangeCandidatesWeightsQuantConfigAttr, ChangeFinalWeightsQuantConfigAttr, ChangeCandidatesActivationQuantConfigAttr, ChangeQuantizationParamFunction, ChangeCandidatesActivationQuantizationMethod, ChangeFinalWeightsQuantizationMethod, ChangeCandidatesWeightsQuantizationMethod, ChangeFinalActivationQuantConfigAttr
16+
from model_compression_toolkit.core.common.network_editors.actions import (
17+
ChangeCandidatesWeightsQuantConfigAttr,
18+
ChangeFinalWeightsQuantConfigAttr,
19+
ChangeCandidatesActivationQuantConfigAttr,
20+
ChangeCandidatesActivationQuantizationMethod,
21+
ChangeFinalWeightsQuantizationMethod,
22+
ChangeCandidatesWeightsQuantizationMethod,
23+
ChangeFinalActivationQuantConfigAttr)
1724
from model_compression_toolkit.core.common.network_editors.actions import EditRule
1825
from model_compression_toolkit.core.common.network_editors.node_filters import NodeTypeFilter, NodeNameScopeFilter, \
1926
NodeNameFilter

model_compression_toolkit/core/common/network_editors/actions.py

Lines changed: 3 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,8 @@
2020
from mct_quantizers import QuantizationMethod
2121
from model_compression_toolkit.core.common import Graph
2222
from model_compression_toolkit.logger import Logger
23-
24-
25-
from model_compression_toolkit.core.common.framework_info import get_fw_info
2623
from model_compression_toolkit.core.common.graph.base_node import BaseNode
27-
from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
28-
get_activation_quantization_params_fn, get_weights_quantization_params_fn
29-
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
30-
get_weights_quantization_fn
24+
3125

3226
_EditRule = namedtuple('EditRule', 'filter action')
3327

@@ -174,47 +168,6 @@ def apply(self, node: BaseNode, graph):
174168
node.final_activation_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value)
175169

176170

177-
class ChangeQuantizationParamFunction(BaseAction):
178-
"""
179-
Class ChangeQuantizationParamFunction to change a node's weights/activations quantization params function.
180-
"""
181-
182-
def __init__(self,
183-
attr_name: str = None,
184-
activation_quantization_params_fn: Callable = None,
185-
weights_quantization_params_fn: Callable = None):
186-
"""
187-
Init a ChangeQuantizationParamFunction object.
188-
189-
Args:
190-
attr_name: The weights attribute's name to set the weights quantization params function for (if setting weights params).
191-
activation_quantization_params_fn: a params function for a node's activations.
192-
weights_quantization_params_fn: a params function for a node's weights.
193-
"""
194-
self.activation_quantization_params_fn = activation_quantization_params_fn
195-
self.weights_quantization_params_fn = weights_quantization_params_fn
196-
self.attr_name = attr_name
197-
198-
def apply(self, node: BaseNode, graph):
199-
"""
200-
Change the node's weights/activations quantization params function.
201-
202-
Args:
203-
node: Node object to change its quantization params function.
204-
graph: Graph to apply the action on.
205-
206-
Returns:
207-
The node after its quantization params function has been modified.
208-
"""
209-
for nqc in node.candidates_quantization_cfg:
210-
if self.activation_quantization_params_fn is not None:
211-
nqc.activation_quantization_cfg.set_activation_quantization_params_fn(
212-
self.activation_quantization_params_fn)
213-
if self.weights_quantization_params_fn is not None:
214-
attr_config = nqc.weights_quantization_cfg.get_attr_config(self.attr_name)
215-
attr_config.override_weights_quantization_params_fn(self.weights_quantization_params_fn)
216-
217-
218171
class ChangeFinalActivationQuantizationMethod(BaseAction):
219172
"""
220173
Class ChangeFinalActivationQuantizationMethod to change a node's weights/activations quantizer function.
@@ -243,16 +196,6 @@ def apply(self, node: BaseNode, graph):
243196
"""
244197

245198
if self.activation_quantization_method is not None and node.final_activation_quantization_cfg is not None:
246-
247-
activation_quantization_params_fn = get_activation_quantization_params_fn(
248-
self.activation_quantization_method)
249-
250-
node.final_activation_quantization_cfg.set_activation_quantization_params_fn(
251-
activation_quantization_params_fn)
252-
253-
activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(self.activation_quantization_method)
254-
255-
node.final_activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
256199
node.final_activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
257200

258201

@@ -281,23 +224,12 @@ def apply(self, node: BaseNode, graph):
281224
"""
282225
if self.activation_quantization_method is not None:
283226
for qc in node.candidates_quantization_cfg:
284-
activation_quantization_params_fn = get_activation_quantization_params_fn(
285-
self.activation_quantization_method)
286-
287-
qc.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
288-
activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(
289-
self.activation_quantization_method)
290-
291-
if activation_quantization_fn is None:
292-
Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
293-
294-
qc.activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
295227
qc.activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
296228

297229

298230
class ChangeFinalWeightsQuantizationMethod(BaseAction):
299231
"""
300-
Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer function.
232+
Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer method.
301233
"""
302234

303235
def __init__(self, attr_name: str, weights_quantization_method=None):
@@ -323,21 +255,8 @@ def apply(self, node: BaseNode, graph):
323255
"""
324256

325257
if self.weights_quantization_method is not None and node.final_weights_quantization_cfg is not None:
326-
327-
weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
328-
329-
attr_config = node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
330-
attr_config.override_weights_quantization_params_fn(weights_quantization_params_fn)
331-
332-
weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
333-
334-
if weights_quantization_fn is None:
335-
Logger.critical('Unknown weights quantization method specified.') # pragma: no cover
336-
337258
attr_config = node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
338-
attr_config.override_weights_quantization_fn(weights_quantization_fn)
339-
node.final_weights_quantization_cfg.get_attr_config(self.attr_name).weights_quantization_method = \
340-
self.weights_quantization_method
259+
attr_config.weights_quantization_method = self.weights_quantization_method
341260

342261

343262
class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
@@ -370,18 +289,7 @@ def apply(self, node: BaseNode, graph: Graph):
370289

371290
if self.weights_quantization_method is not None:
372291
for qc in node.candidates_quantization_cfg:
373-
374-
weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
375-
376292
attr_qc = qc.weights_quantization_cfg.get_attr_config(self.attr_name)
377-
attr_qc.override_weights_quantization_params_fn(weights_quantization_params_fn)
378-
379-
weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
380-
381-
if weights_quantization_fn is None:
382-
Logger.critical('Unknown weights quantization method specified.') # pragma: no cover
383-
384-
attr_qc.override_weights_quantization_fn(weights_quantization_fn)
385293
attr_qc.weights_quantization_method = self.weights_quantization_method
386294

387295

0 commit comments

Comments (0)