Skip to content

Commit 7d0b412

Browse files
irenabirenab
authored and committed
remove weights_second_moment_correction and weights_bias_correction from NodeWeightsQuantizationConfig
1 parent 39c8d63 commit 7d0b412

10 files changed

Lines changed: 22 additions & 58 deletions

File tree

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
1919
from model_compression_toolkit.logger import Logger
2020

21-
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
2221
from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
2322
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
2423
AttributeQuantizationConfig, OpQuantizationConfig
@@ -64,6 +63,7 @@ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_val
6463
if hasattr(self, config_parameter_name):
6564
setattr(self, config_parameter_name, config_parameter_value)
6665
else:
66+
raise AttributeError(config_parameter_name)
6767
Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
6868
f"was not updated!")
6969

@@ -272,14 +272,11 @@ def __init__(self,
272272

273273
self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
274274
weights_channels_axis=weights_channels_axis)
275-
# TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
275+
# TODO this is set by batch norm reconstruction substitution when folded batch norms are added back, to mark
276+
# the nodes that the correction should be applied to (for some nodes it gets disabled) and BNs removed.
277+
# The actual correction is only computed when it's applied in ptq, so it seems that both substitutions could
278+
# be unified, and no info need to pass between.
276279
self.weights_second_moment_correction = None
277-
self.weights_bias_correction = None
278-
279-
def set_qc(self, qc: QuantizationConfig):
280-
# TODO irena: temporary keep the fields to not break everything at once.
281-
self.weights_second_moment_correction = qc.weights_second_moment_correction
282-
self.weights_bias_correction = qc.weights_bias_correction
283280

284281
def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
285282
"""
@@ -435,8 +432,6 @@ def __eq__(self, other: Any) -> bool:
435432
return False # pragma: no cover
436433

437434
return self.simd_size == other.simd_size and \
438-
self.weights_second_moment_correction == other.weights_second_moment_correction and \
439-
self.weights_bias_correction == other.weights_bias_correction and \
440435
self.attributes_config_mapping.keys() == other.attributes_config_mapping.keys() and \
441436
all([self.attributes_config_mapping[k] == other.attributes_config_mapping[k]
442437
for k in self.attributes_config_mapping.keys()]) and \
@@ -446,7 +441,5 @@ def __eq__(self, other: Any) -> bool:
446441

447442
def __hash__(self):
448443
return hash((self.simd_size,
449-
self.weights_second_moment_correction,
450-
self.weights_bias_correction,
451444
frozenset(self.attributes_config_mapping),
452445
frozenset(self.pos_attributes_config_mapping)))

model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,20 @@
1414
# ==============================================================================
1515
import copy
1616

17-
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
18-
from model_compression_toolkit.core import CoreConfig
1917
from model_compression_toolkit.core.common import Graph, BaseNode
2018
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
2119
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
2220
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig
2321

2422

2523
def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
26-
core_config: CoreConfig,
2724
fw_impl: FrameworkImplementation) -> Graph:
2825
"""
2926
Get a graph, where each node has a final weights quantization configuration (with a bias
3027
correction term in it), and apply the bias correction for each node in the graph.
3128
3229
Args:
3330
graph_to_apply_bias_correction: Graph to apply bias correction to.
34-
core_config: CoreConfig containing parameters of how the model should be quantized.
3531
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
3632
3733
Returns:
@@ -41,19 +37,16 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
4137
graph = copy.deepcopy(graph_to_apply_bias_correction)
4238
for n in graph.nodes:
4339
# bias correction is only relevant for nodes with kernel op
44-
if core_config.quantization_config.weights_bias_correction and n.kernel_attr is not None and \
45-
n.is_weights_quantization_enabled(n.kernel_attr) and \
40+
if n.kernel_attr is not None and n.is_weights_quantization_enabled(n.kernel_attr) and \
4641
not n.final_weights_quantization_cfg.weights_second_moment_correction:
4742
# If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg,
4843
# a bias correction term was calculated during model preparation, and is used now in the node's bias term.
49-
if n.final_weights_quantization_cfg.weights_bias_correction:
50-
_apply_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
44+
_apply_bias_correction_to_node(n, fw_impl)
5145
return graph
5246

5347

5448
def _apply_bias_correction_to_node(node: BaseNode,
55-
fw_impl: FrameworkImplementation,
56-
qc: QuantizationConfig):
49+
fw_impl: FrameworkImplementation):
5750
"""
5851
Set new bias to node using the bias-correction term that is stored in the
5952
final weights quantization configuration.

model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ def _compute_bias_correction_per_candidate_qc(node: BaseNode,
7474
"""
7575

7676
for candidate_qc in node.candidates_quantization_cfg:
77-
if candidate_qc.weights_quantization_cfg.weights_bias_correction and not \
78-
candidate_qc.weights_quantization_cfg.weights_second_moment_correction:
77+
if not candidate_qc.weights_quantization_cfg.weights_second_moment_correction:
7978

8079
quantized_kernel, io_channels_axes = get_quantized_weights_attr_by_qc(kernel_attr,
8180
node,

model_compression_toolkit/core/common/statistics_correction/statistics_correction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ def statistics_correction_runner(transformed_graph: Graph,
5656
########################################################
5757
# Compute bias correction to nodes' config candidates
5858
########################################################
59-
tg_with_bias = compute_bias_correction_of_graph(tg_with_bias,
60-
fw_impl)
59+
if core_config.quantization_config.weights_bias_correction:
60+
tg_with_bias = compute_bias_correction_of_graph(tg_with_bias,
61+
fw_impl)
6162

6263
if tb_w is not None:
6364
tb_w.add_graph(tg_with_bias, 'statistics_computation')
@@ -96,7 +97,6 @@ def apply_statistics_correction(transformed_graph: Graph,
9697
#############################################
9798
if core_config.quantization_config.weights_bias_correction:
9899
transformed_graph = apply_bias_correction_to_graph(transformed_graph,
99-
core_config,
100100
fw_impl=fw_impl)
101101
if tb_w is not None:
102102
tb_w.add_graph(transformed_graph, 'after_statistics_correction')

model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import numpy as np
2121

2222
from model_compression_toolkit.core.common import Graph
23-
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
2423
from model_compression_toolkit.core import common
2524
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
2625
ActivationQuantizationMode
@@ -84,14 +83,10 @@ def substitute(self,
8483
# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
8584
# we should skip the substitution.
8685
if source_node.is_reused():
87-
for qc in source_node.candidates_quantization_cfg:
88-
qc.weights_quantization_cfg.weights_second_moment_correction = False
8986
return graph
9087

9188
# We apply only on nodes with folded BatchNormalization.
9289
if source_node.prior_info.std_output is None or source_node.prior_info.mean_output is None:
93-
for qc in source_node.candidates_quantization_cfg:
94-
qc.weights_quantization_cfg.weights_second_moment_correction = False
9590
return graph
9691

9792
# This feature disabled for models with weights quantization method of Power of 2
@@ -103,10 +98,13 @@ def substitute(self,
10398
== QuantizationMethod.POWER_OF_TWO):
10499
Logger.warning("Second moment statistics correction feature disabled for models with weights "
105100
"quantization method of Power of 2")
106-
for qc_inner in source_node.candidates_quantization_cfg:
107-
qc_inner.weights_quantization_cfg.weights_second_moment_correction = False
108101
return graph
109102

103+
# turn on second moment correction flag
104+
def set_second_moment_correction(qc):
105+
qc.weights_quantization_cfg.weights_second_moment_correction = True
106+
source_node.quantization_cfg.update_all(set_second_moment_correction)
107+
110108
eps = self.epsilon_val
111109

112110
original_gamma = source_node.prior_info.std_output

model_compression_toolkit/core/graph_prep_runner.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,16 +153,10 @@ def get_finalized_graph(initial_graph: Graph,
153153
if bit_width_config:
154154
set_manual_bitwidth_config(graph, bit_width_config)
155155

156-
# TODO irena: load_fqc_configuration only loads config from tpc. Previously quant_config was read as well.
157-
# As a first stage we keep the attributes in internal configs and fill them manually from quant_config
158-
# not to break all the code at once. Eventually we need to handle quant_config directly, without injecting into candidates.
159-
# TODO 2: Also we adjust candidates for single precision, which we shouldn't do here.
160-
def update(qc):
161-
qc.weights_quantization_cfg.set_qc(quant_config)
156+
# TODO irena: remove after base config is used
162157
for n in transformed_graph.nodes:
163158
if not mixed_precision_enable:
164159
n.quantization_cfg.candidates_quantization_cfg = [n.quantization_cfg.base_quantization_cfg]
165-
n.quantization_cfg.update_all(update)
166160

167161
######################################
168162
# Channel equalization

tests/keras_tests/feature_networks_tests/feature_networks/network_editor/edit_qc_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ class ChangeCandidatesWeightsQuantConfigAttrTest(BaseChangeQuantConfigAttrTest):
224224

225225
def __init__(self, unit_test):
226226
edit_filter = NodeTypeFilter(layers.Conv2D)
227-
action = ChangeCandidatesWeightsQuantConfigAttr(weights_bias_correction=False)
227+
action = ChangeCandidatesWeightsQuantConfigAttr(weights_second_moment_correction=True)
228228
prepare_graph_func = prepare_graph_for_first_network_editor
229229
super().__init__(unit_test, edit_filter=edit_filter, action=action, prepare_graph_func=prepare_graph_func)
230230

@@ -242,7 +242,7 @@ class ChangeFinalsWeightsQuantConfigAttrTest(BaseChangeQuantConfigAttrTest):
242242

243243
def __init__(self, unit_test):
244244
edit_filter = NodeTypeFilter(layers.Conv2D)
245-
action = ChangeFinalWeightsQuantConfigAttr(weights_bias_correction=False)
245+
action = ChangeFinalWeightsQuantConfigAttr(weights_second_moment_correction=True)
246246
prepare_graph_func = prepare_graph_for_second_network_editor
247247
super().__init__(unit_test, edit_filter=edit_filter, action=action, prepare_graph_func=prepare_graph_func)
248248

tests/keras_tests/feature_networks_tests/test_features_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ def test_depthwise_conv2d_replacement(self):
182182
DwConv2dReplacementTest(self).run_test()
183183

184184
def test_change_qc_attr(self):
185-
ChangeFinalWeightQCAttrTest(self).run_test()
185+
# there are no fields that can be changed in final cfg and have any effect (unless the whole attr cfgs mapping is overridden)
186+
# ChangeFinalWeightQCAttrTest(self).run_test()
186187
ChangeFinalActivationQCAttrTest(self).run_test()
187188

188189
def test_edit_candidate_qc(self):

tests/keras_tests/function_tests/test_node_quantization_configurations.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,6 @@ def test_weights_set_quant_config_attribute(self):
5353
node_attrs_list=[KERNEL, 0])
5454
og_nwc = copy.deepcopy(nwc)
5555

56-
# Updating a config parameter, not weights attribute parameter (no attr_name passed)
57-
# TODO irena: weights_bias_correction should be removed
58-
# self.assertTrue(nwc.weights_bias_correction)
59-
nwc.set_quant_config_attr("weights_bias_correction", False)
60-
self.assertFalse(nwc.weights_bias_correction)
61-
self.assertFalse(nwc == og_nwc)
62-
63-
nwc = copy.deepcopy(og_nwc)
64-
6556
# Updating an attribute parameter
6657
self.assertTrue(nwc.get_attr_config(KERNEL).weights_n_bits, 8)
6758
nwc.set_quant_config_attr("weights_n_bits", 4, attr_name=KERNEL)

tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,8 @@ def dummy_representative_dataset():
7272
graph = load_fqc_configuration(graph=graph, fqc=fqc)
7373

7474
for node in graph.nodes:
75-
# TODO irena remove set_qc:
76-
for c in node.quantization_cfg.candidates_quantization_cfg:
77-
c.weights_quantization_cfg.set_qc(core_config.quantization_config)
78-
7975
node.prior_info = keras_impl.get_node_prior_info(node=node,
8076
graph=graph)
81-
8277
mi = ModelCollector(graph,
8378
fw_impl=keras_impl,
8479
qc=core_config.quantization_config)

0 commit comments

Comments
 (0)