Skip to content

Commit 8621d47

Browse files
authored
Fix activation bias correction with preserving node (#1449)
* Fix activation bias correction with preserving node
1 parent: cc8189d · commit: 8621d47

4 files changed

Lines changed: 13 additions & 7 deletions

File tree

model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def set_bit_widths(mixed_precision_enable: bool,
7676
for n in graph.nodes:
7777
assert len(n.candidates_quantization_cfg) == 1
7878
n.final_weights_quantization_cfg = copy.deepcopy(n.candidates_quantization_cfg[0].weights_quantization_cfg)
79-
n.final_activation_quantization_cfg = copy.deepcopy(n.candidates_quantization_cfg[0].activation_quantization_cfg)
79+
if not n.is_quantization_preserving():
80+
n.final_activation_quantization_cfg = copy.deepcopy(n.candidates_quantization_cfg[0].activation_quantization_cfg)
8081

8182
return graph
8283

model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,9 @@ def get_previous_node_with_activation_quantization(linear_node: BaseNode,
4242

4343
prev_node = prev_nodes[0]
4444

45-
activation_quantization_config = prev_node.final_activation_quantization_cfg
45+
prev_quant_node = graph.retrieve_preserved_quantization_node(prev_node)
4646

47-
# Search for node with activation quantization
48-
if activation_quantization_config.enable_activation_quantization:
49-
return prev_node
50-
else:
51-
return get_previous_node_with_activation_quantization(prev_node, graph)
47+
return prev_quant_node if prev_quant_node.is_activation_quantization_enabled() else None
5248

5349

5450
def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:

tests/keras_tests/feature_networks_tests/feature_networks/activation_bias_correction_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import numpy as np
2121

2222
from tests.keras_tests.utils import get_layers_from_model_by_type
23+
from tests.common_tests.helpers.tpcs_for_tests.v4.tpc import get_tpc
2324

2425
keras = tf.keras
2526
layers = keras.layers
@@ -52,6 +53,10 @@ def get_quantization_config(self):
5253
activation_bias_correction=True,
5354
activation_bias_correction_threshold=self.activation_bias_correction_threshold)
5455

56+
def get_tpc(self):
57+
tpc = get_tpc()
58+
return tpc
59+
5560
def create_networks(self):
5661
inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
5762
x = self.prev_layer(inputs)

tests/pytorch_tests/model_tests/feature_models/activation_bias_correction_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ def get_quantization_config(self):
106106
def create_networks(self):
107107
return self.model
108108

109+
def get_tpc(self):
110+
from tests.common_tests.helpers.tpcs_for_tests.v4.tpc import get_tpc
111+
return get_tpc()
112+
109113
def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
110114
bias = float_model.linear_layer.bias.cpu().detach().numpy()
111115
bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy()

0 commit comments

Comments (0)