@@ -119,11 +119,16 @@ def filter_node_qco_by_graph(node: BaseNode,
                 _next_nodes.extend(graph.get_next_nodes(n))
             next_nodes.append(n)
 
-    if len(next_nodes):
-        next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
-        next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
-                                                   for qc_opts in next_nodes_qc_options
-                                                   for op_cfg in qc_opts.quantization_configurations])
+    if len(next_nodes) == 0:
+        return _base_config, _node_qc_options
+    next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
+    all_next_nodes_supported_input_bitwidth = [max_input_activation_n_bits(op_cfg)
+                                               for qc_opts in next_nodes_qc_options
+                                               for op_cfg in qc_opts.quantization_configurations
+                                               if op_cfg.enable_activation_quantization or op_cfg.quantization_preserving
@yarden-yagil-sony (Contributor, Author) commented on Apr 23, 2025:

@ofirgo
Here the "if" condition was missing, which led to considering the quantization configs of nodes with enable_activation_quantization=False and quantization_preserving=False.
This led to an error when these irrelevant nodes support only 8 bits rather than both 8 and 16 bits: a node manually set to bit-width 16 could end up with next nodes that support only 8 bits (since we take the minimum of the next nodes' supported input bit-widths, line 131).

+                                               ]
+    if len(all_next_nodes_supported_input_bitwidth):
+        next_nodes_supported_input_bitwidth = min(all_next_nodes_supported_input_bitwidth)
 
     # Filter node's QC options that match next nodes input bit-width.
     _node_qc_options = [_option for _option in _node_qc_options
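To make the bug described in the review comment above concrete, here is a minimal, self-contained Python sketch; the Cfg dataclass is a hypothetical stand-in for OpQuantizationConfig, not MCT's real class:

# Without the "if" filter, a next node that neither quantizes nor preserves
# activations (and supports only 8-bit inputs) drags the minimum down and
# invalidates a 16-bit candidate on the current node.
from dataclasses import dataclass

@dataclass
class Cfg:  # hypothetical stand-in for OpQuantizationConfig
    supported_input_activation_n_bits: tuple
    enable_activation_quantization: bool
    quantization_preserving: bool

next_cfgs = [Cfg((8, 16), True, False),  # quantized next node: supports 8 and 16
             Cfg((8,), False, False)]    # irrelevant next node: no quant, no preserving

buggy_min = min(max(c.supported_input_activation_n_bits) for c in next_cfgs)
fixed_min = min(max(c.supported_input_activation_n_bits) for c in next_cfgs
                if c.enable_activation_quantization or c.quantization_preserving)

assert buggy_min == 8   # the old code would wrongly filter out 16-bit candidates
assert fixed_min == 16  # only nodes that actually consume quantized inputs count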
@@ -205,7 +210,7 @@ def set_quantization_configs_to_node(node: BaseNode,
             # Preserving the quantization of more than 1 previous node is ambiguous, so disable it.
             Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than 1 input activations.")
             candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
-        elif not prev_nodes[0].is_quantization_preserving() or not prev_nodes[0].is_activation_quantization_enabled():
+        elif not prev_nodes[0].is_quantization_preserving() and not prev_nodes[0].is_activation_quantization_enabled():
             # Preserving the quantization of an unquantized node isn't possible, so disable it.
             Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.")
             candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
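The one-token change above flips the disable condition from "previous node is not preserving or not quantized" to "neither preserving nor quantized". A minimal sketch of the difference, using a hypothetical Prev class rather than MCT's node type:

# With the old `or`, a previous node that preserves quantization but has its own
# activation quantization disabled (a legal combination) wrongly triggered the
# disable branch; with `and`, only a fully unquantized predecessor does.
class Prev:  # hypothetical stand-in for a graph node
    def __init__(self, preserving: bool, quantized: bool):
        self._preserving, self._quantized = preserving, quantized

    def is_quantization_preserving(self) -> bool:
        return self._preserving

    def is_activation_quantization_enabled(self) -> bool:
        return self._quantized

prev = Prev(preserving=True, quantized=False)  # preserving, but not itself quantized
old = not prev.is_quantization_preserving() or not prev.is_activation_quantization_enabled()
new = not prev.is_quantization_preserving() and not prev.is_activation_quantization_enabled()
assert old       # old logic wrongly disabled quantization preserving here
assert not new   # fixed logic keeps it enabled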
@@ -343,6 +343,13 @@ def shift_negative_function(graph: Graph,
     graph.set_out_stats_collector_to_node(add_node, add_node_stats_collector)
     graph.shift_stats_collector(add_node, np.array(shift_value))
 
+    set_quantization_configs_to_node(fw_info=fw_info,
+                                     node=add_node,
+                                     graph=graph,
+                                     quant_config=core_config.quantization_config,
+                                     fqc=graph.fqc,
+                                     mixed_precision_enable=core_config.is_mixed_precision_enabled)
+
     if padding is not None:
         pad_node = create_pad_node(op2d_node.name,
                                    add_node.name,
@@ -373,13 +380,6 @@
 
     op2d_node.input_shape = pad_node.output_shape
 
-    set_quantization_configs_to_node(fw_info=fw_info,
-                                     node=add_node,
-                                     graph=graph,
-                                     quant_config=core_config.quantization_config,
-                                     fqc=graph.fqc,
-                                     mixed_precision_enable=core_config.is_mixed_precision_enabled)
-
     original_non_linear_activation_nbits = non_linear_node_cfg_candidate.activation_n_bits
     # The non-linear node's output should be float, so we approximate it by using 16bits quantization.
     for candidate_qc in non_linear_node.candidates_quantization_cfg:
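As a side note on the comment inside the hunk above ("the non-linear node's output should be float"), a minimal sketch of that approximation, with hypothetical stand-in classes rather than MCT's real config types:

# A float output is approximated by forcing the widest (16-bit) activation
# quantization on every candidate config of the non-linear node, after
# remembering the original width for later restoration.
FLOAT_APPROX_NBITS = 16

class ActQCfg:  # hypothetical stand-in for an activation quantization config
    def __init__(self, n_bits: int):
        self.activation_n_bits = n_bits

candidates = [ActQCfg(8), ActQCfg(4)]
original_nbits = candidates[0].activation_n_bits  # kept so it can be restored later
for cfg in candidates:
    cfg.activation_n_bits = FLOAT_APPROX_NBITS  # wide enough that quantization error is negligible

assert original_nbits == 8 and all(c.activation_n_bits == 16 for c in candidates)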
@@ -25,37 +25,105 @@
 from mct_quantizers import QuantizationMethod
 
 
+class PreservingNode:
+    pass
+
+
+class NoActivationQuantNode:
+    pass
+
+
 class TestSetNodeQuantizationConfig:
 
     @staticmethod
-    def _get_op_config():
+    def _get_op_config(activation_n_bits,
+                       supported_input_activation_n_bits,
+                       enable_activation_quantization,
+                       quantization_preserving):
         aqc = AttributeQuantizationConfig()
         return OpQuantizationConfig(default_weight_attr_config=aqc,
                                     attr_weights_configs_mapping={'w': aqc},
                                     activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
-                                    activation_n_bits=7,
-                                    supported_input_activation_n_bits=7,
-                                    enable_activation_quantization=False,
-                                    quantization_preserving=True,
+                                    activation_n_bits=activation_n_bits,
+                                    supported_input_activation_n_bits=supported_input_activation_n_bits,
+                                    enable_activation_quantization=enable_activation_quantization,
+                                    quantization_preserving=quantization_preserving,
                                     signedness=Signedness.AUTO)

     def test_activation_preserving_with_2_inputs(self, fw_info_mock):
         """ Tests that quantization preserving is disabled for a node with 2 input activations and kept for single-input preserving nodes. """
         n1 = build_node('in1_node')
         n2 = build_node('in2_node')
-        n3 = build_node('qp_node')
-        n4 = build_node('qp2_node')
-        graph = Graph('g', input_nodes=[n1, n2], nodes=[n3], output_nodes=[n4],
+        n3 = build_node('qp_node', layer_class=PreservingNode)
+        n4 = build_node('qp2_node', layer_class=PreservingNode)
+        qp3 = build_node('qp3_node', layer_class=PreservingNode)
+        qp4 = build_node('qp4_node', layer_class=PreservingNode)
+        graph = Graph('g', input_nodes=[n1, n2], nodes=[n3, qp3], output_nodes=[n4, qp4],
                       edge_list=[Edge(n1, n3, 0, 0), Edge(n2, n3, 0, 0),
-                                 Edge(n3, n4, 0, 0)])
+                                 Edge(n3, n4, 0, 0),
+                                 Edge(n1, qp3, 0, 0), Edge(qp3, qp4, 0, 0)])
+        q_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7,
+                              "enable_activation_quantization": True, "quantization_preserving": False}
+        qp_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7,
+                               "enable_activation_quantization": False, "quantization_preserving": True}
+        _filters = {DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**q_op_config_kwargs)]),
+                    PreservingNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**qp_op_config_kwargs)])}
+        fqc = Mock(filterlayer2qco=_filters, layer2qco=_filters)
 
-        fqc = Mock(filterlayer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])},
-                   layer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])})
         fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
                             activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
                             get_kernel_op_attributes=lambda x: [None])
-        set_quantization_configs_to_node(n3, graph, QuantizationConfig(), fw_info_mock, fqc)
-        set_quantization_configs_to_node(n4, graph, QuantizationConfig(), fw_info_mock, fqc)
+        qc = QuantizationConfig()
+        for n in graph.get_topo_sorted_nodes():
+            set_quantization_configs_to_node(n, graph, qc, fw_info_mock, fqc)
         assert not n3.is_quantization_preserving() and not n3.is_activation_quantization_enabled()
         assert not n4.is_quantization_preserving() and not n4.is_activation_quantization_enabled()
+        assert qp3.is_quantization_preserving()
+        assert qp4.is_quantization_preserving()

+    def test_node_quantization_by_next_nodes(self, fw_info_mock):
+        """
+        Test that a node's activation n_bits is not constrained by a quantization-preserving next node
+        or by a next node whose activation quantization is disabled.
+        """
+        first_node = build_node('first_node')
+        preserving_node = build_node('preserving_node', layer_class=PreservingNode)
+        no_quant_node = build_node('no_enabled_quant_node', layer_class=NoActivationQuantNode)
+        graph = Graph('g', input_nodes=[first_node], nodes=[preserving_node], output_nodes=[no_quant_node],
+                      edge_list=[Edge(first_node, preserving_node, 0, 0),
+                                 Edge(preserving_node, no_quant_node, 0, 0)])
+
+        first_node_config_kwargs = {"activation_n_bits": 16,
+                                    "supported_input_activation_n_bits": [8, 16],
+                                    "enable_activation_quantization": True,
+                                    "quantization_preserving": False}
+
+        preserving_node_config_kwargs = {"activation_n_bits": 8,
+                                         "supported_input_activation_n_bits": [8, 16],
+                                         "enable_activation_quantization": False,
+                                         "quantization_preserving": True}
+
+        no_quant_node_config_kwargs = {"activation_n_bits": 8,
+                                       "supported_input_activation_n_bits": [8],
+                                       "enable_activation_quantization": False,
+                                       "quantization_preserving": False}
+        _filters = {
+            DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**first_node_config_kwargs)]),
+            PreservingNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**preserving_node_config_kwargs)]),
+            NoActivationQuantNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**no_quant_node_config_kwargs)]),
+        }
+
+        fqc = Mock(filterlayer2qco=_filters, layer2qco=_filters)
+        fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
+                            activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
+                            get_kernel_op_attributes=lambda x: [None])
+        quantization_config = QuantizationConfig()
+        set_quantization_configs_to_node(first_node, graph, quantization_config, fw_info_mock, fqc)
+        set_quantization_configs_to_node(preserving_node, graph, quantization_config, fw_info_mock, fqc)
+        set_quantization_configs_to_node(no_quant_node, graph, quantization_config, fw_info_mock, fqc)
+
+        assert not first_node.is_quantization_preserving() and first_node.is_activation_quantization_enabled()
+        assert preserving_node.is_quantization_preserving() and not preserving_node.is_activation_quantization_enabled()
+        assert not no_quant_node.is_quantization_preserving() and not no_quant_node.is_activation_quantization_enabled()
+
+        # Assert that first_node's n_bits is 16 and isn't affected by its next nodes, which support only 8 bits.
+        assert first_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_n_bits == 16