Skip to content

Commit 532f6ee

Browse files
Node Supported Bitwidth Filtering Bugfix (#1422)
* bugfix: filter out next-node quantization configs that neither enable activation quantization nor are quantization-preserving * bugfix: in the shift-negative substitution, set the add-node quantization config before creating the padding node * bugfix: fix the quantization-preserving validation (previous node must be neither preserving nor quantization-enabled to disable preserving)
1 parent 8597082 commit 532f6ee

File tree

3 files changed

+98
-25
lines changed

3 files changed

+98
-25
lines changed

model_compression_toolkit/core/common/quantization/set_node_quantization_config.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,16 @@ def filter_node_qco_by_graph(node: BaseNode,
119119
_next_nodes.extend(graph.get_next_nodes(n))
120120
next_nodes.append(n)
121121

122-
if len(next_nodes):
123-
next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
124-
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
122+
if len(next_nodes) == 0:
123+
return _base_config, _node_qc_options
124+
next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
125+
all_next_nodes_supported_input_bitwidth = [max_input_activation_n_bits(op_cfg)
125126
for qc_opts in next_nodes_qc_options
126-
for op_cfg in qc_opts.quantization_configurations])
127+
for op_cfg in qc_opts.quantization_configurations
128+
if op_cfg.enable_activation_quantization or op_cfg.quantization_preserving
129+
]
130+
if len(all_next_nodes_supported_input_bitwidth):
131+
next_nodes_supported_input_bitwidth = min(all_next_nodes_supported_input_bitwidth)
127132

128133
# Filter node's QC options that match next nodes input bit-width.
129134
_node_qc_options = [_option for _option in _node_qc_options
@@ -205,7 +210,7 @@ def set_quantization_configs_to_node(node: BaseNode,
205210
# Preserving the quantization of more than 1 previous node is ambiguous, so disable it.
206211
Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than 1 input activations.")
207212
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
208-
elif not prev_nodes[0].is_quantization_preserving() or not prev_nodes[0].is_activation_quantization_enabled():
213+
elif not prev_nodes[0].is_quantization_preserving() and not prev_nodes[0].is_activation_quantization_enabled():
209214
# Preserving the quantization of an unquantized node isn't possible, so disable it.
210215
Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.")
211216
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT

model_compression_toolkit/core/common/substitutions/shift_negative_activation.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,13 @@ def shift_negative_function(graph: Graph,
343343
graph.set_out_stats_collector_to_node(add_node, add_node_stats_collector)
344344
graph.shift_stats_collector(add_node, np.array(shift_value))
345345

346+
set_quantization_configs_to_node(fw_info=fw_info,
347+
node=add_node,
348+
graph=graph,
349+
quant_config=core_config.quantization_config,
350+
fqc=graph.fqc,
351+
mixed_precision_enable=core_config.is_mixed_precision_enabled)
352+
346353
if padding is not None:
347354
pad_node = create_pad_node(op2d_node.name,
348355
add_node.name,
@@ -373,13 +380,6 @@ def shift_negative_function(graph: Graph,
373380

374381
op2d_node.input_shape = pad_node.output_shape
375382

376-
set_quantization_configs_to_node(fw_info=fw_info,
377-
node=add_node,
378-
graph=graph,
379-
quant_config=core_config.quantization_config,
380-
fqc=graph.fqc,
381-
mixed_precision_enable=core_config.is_mixed_precision_enabled)
382-
383383
original_non_linear_activation_nbits = non_linear_node_cfg_candidate.activation_n_bits
384384
# The non-linear node's output should be float, so we approximate it by using 16bits quantization.
385385
for candidate_qc in non_linear_node.candidates_quantization_cfg:

tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py

Lines changed: 81 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,37 +25,105 @@
2525
from mct_quantizers import QuantizationMethod
2626

2727

28+
class PreservingNode:
29+
pass
30+
31+
32+
class NoActivationQuantNode:
33+
pass
34+
35+
2836
class TestSetNodeQuantizationConfig:
2937

3038
@staticmethod
31-
def _get_op_config():
39+
def _get_op_config(activation_n_bits,
40+
supported_input_activation_n_bits,
41+
enable_activation_quantization,
42+
quantization_preserving):
3243
aqc = AttributeQuantizationConfig()
3344
return OpQuantizationConfig(default_weight_attr_config=aqc,
3445
attr_weights_configs_mapping={'w': aqc},
3546
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
36-
activation_n_bits=7,
37-
supported_input_activation_n_bits=7,
38-
enable_activation_quantization=False,
39-
quantization_preserving=True,
47+
activation_n_bits=activation_n_bits,
48+
supported_input_activation_n_bits=supported_input_activation_n_bits,
49+
enable_activation_quantization=enable_activation_quantization,
50+
quantization_preserving=quantization_preserving,
4051
signedness=Signedness.AUTO)
4152

4253
def test_activation_preserving_with_2_inputs(self, fw_info_mock):
4354
""" Tests that . """
4455
n1 = build_node('in1_node')
4556
n2 = build_node('in2_node')
46-
n3 = build_node('qp_node')
47-
n4 = build_node('qp2_node')
48-
graph = Graph('g', input_nodes=[n1, n2], nodes=[n3], output_nodes=[n4],
57+
n3 = build_node('qp_node', layer_class=PreservingNode)
58+
n4 = build_node('qp2_node', layer_class=PreservingNode)
59+
qp3 = build_node('qp3_node', layer_class=PreservingNode)
60+
qp4 = build_node('qp4_node', layer_class=PreservingNode)
61+
graph = Graph('g', input_nodes=[n1, n2], nodes=[n3, qp3], output_nodes=[n4, qp4],
4962
edge_list=[Edge(n1, n3, 0, 0), Edge(n2, n3, 0, 0),
50-
Edge(n3, n4, 0, 0)])
63+
Edge(n3, n4, 0, 0),
64+
Edge(n1, qp3, 0, 0), Edge(qp3, qp4, 0, 0)])
65+
q_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7,
66+
"enable_activation_quantization": True, "quantization_preserving": False}
67+
qp_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7,
68+
"enable_activation_quantization": False, "quantization_preserving": True}
69+
_filters = {DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**q_op_config_kwargs)]),
70+
PreservingNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**qp_op_config_kwargs)])}
71+
fqc = Mock(filterlayer2qco=_filters, layer2qco=_filters)
5172

52-
fqc = Mock(filterlayer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])},
53-
layer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])})
5473
fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
5574
activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
5675
get_kernel_op_attributes=lambda x: [None])
57-
set_quantization_configs_to_node(n3, graph, QuantizationConfig(), fw_info_mock, fqc)
58-
set_quantization_configs_to_node(n4, graph, QuantizationConfig(), fw_info_mock, fqc)
76+
qc = QuantizationConfig()
77+
for n in graph.get_topo_sorted_nodes():
78+
set_quantization_configs_to_node(n, graph, qc, fw_info_mock, fqc)
5979
assert not n3.is_quantization_preserving() and not n3.is_activation_quantization_enabled()
6080
assert not n4.is_quantization_preserving() and not n4.is_activation_quantization_enabled()
81+
assert qp3.is_quantization_preserving()
82+
assert qp4.is_quantization_preserving()
83+
84+
def test_node_quantization_by_next_nodes(self, fw_info_mock):
85+
"""
86+
Test that node quantization n_bits is unaffected by preserving next node and not-enabled quantization next node.
87+
"""
88+
first_node = build_node('first_node')
89+
preserving_node = build_node('preserving_node', layer_class=PreservingNode)
90+
no_quant_node = build_node('no_enabled_quant_node', layer_class=NoActivationQuantNode)
91+
graph = Graph('g', input_nodes=[first_node], nodes=[preserving_node], output_nodes=[no_quant_node],
92+
edge_list=[Edge(first_node, preserving_node, 0, 0),
93+
Edge(preserving_node, no_quant_node, 0, 0)])
94+
95+
first_node_config_kwargs = {"activation_n_bits": 16,
96+
"supported_input_activation_n_bits": [8, 16],
97+
"enable_activation_quantization": True,
98+
"quantization_preserving": False}
99+
100+
preserving_node_config_kwargs = {"activation_n_bits": 8,
101+
"supported_input_activation_n_bits": [8, 16],
102+
"enable_activation_quantization": False,
103+
"quantization_preserving": True}
104+
105+
no_quant_node_config_kwargs = {"activation_n_bits": 8,
106+
"supported_input_activation_n_bits": [8],
107+
"enable_activation_quantization": False,
108+
"quantization_preserving": False}
109+
_filters = {
110+
DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**first_node_config_kwargs)]),
111+
PreservingNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**preserving_node_config_kwargs)]),
112+
NoActivationQuantNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**no_quant_node_config_kwargs)]),
113+
}
114+
115+
fqc = Mock(filterlayer2qco=_filters, layer2qco=_filters)
116+
fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
117+
activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
118+
get_kernel_op_attributes=lambda x: [None])
119+
quantization_config = QuantizationConfig()
120+
set_quantization_configs_to_node(first_node, graph, quantization_config, fw_info_mock, fqc)
121+
set_quantization_configs_to_node(preserving_node, graph, quantization_config, fw_info_mock, fqc)
122+
set_quantization_configs_to_node(no_quant_node, graph, quantization_config, fw_info_mock, fqc)
123+
124+
assert not first_node.is_quantization_preserving() and first_node.is_activation_quantization_enabled()
125+
assert preserving_node.is_quantization_preserving() and not preserving_node.is_activation_quantization_enabled()
126+
assert not no_quant_node.is_quantization_preserving() and not no_quant_node.is_activation_quantization_enabled()
61127

128+
# assert that first_node n_bits is 16, and isn't affected by its next nodes which supports 8 n_bits only
129+
assert first_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_n_bits == 16

0 commit comments

Comments
 (0)