Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations

from typing import Any, List, Dict, TYPE_CHECKING
from enum import Enum, auto

Expand All @@ -24,6 +26,7 @@
AttributeQuantizationConfig, OpQuantizationConfig

if TYPE_CHECKING:
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT

##########################################
Expand Down Expand Up @@ -148,7 +151,7 @@ def set_activation_quantization_param(self,
activation_params: Dictionary that contains weight quantization params.

"""
assert self.quant_mode == ActivationQuantizationMode.QUANT
assert self.quant_mode == ActivationQuantizationMode.QUANT or self.quant_mode == ActivationQuantizationMode.FLN_QUANT
for param_name, param_value in activation_params.items():
self.activation_quantization_params[param_name] = param_value

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def calculate_quantization_params(graph: Graph,
attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input)
attr_cfg.set_weights_quantization_param(weights_params)

if n.is_activation_quantization_enabled():
if n.is_activation_quantization_enabled() or n.is_fln_quantization():
# If node's activations should be quantized as well, we compute its activation quantization parameters
activation_params = compute_activation_qparams(
activation_quant_cfg=candidate_qc.activation_quantization_cfg, node_prior_info=n.prior_info,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import numpy as np
import torch

from unittest.mock import Mock
from typing import List
from model_compression_toolkit.core.common import Graph, BaseNode

from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \
calculate_quantization_params
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
CandidateNodeQuantizationConfig, NodeQuantizationConfig
from model_compression_toolkit.core.common.quantization.node_quantization_config import \
ActivationQuantizationMode, NodeActivationQuantizationConfig
from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig
from model_compression_toolkit.core import QuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness, \
AttributeQuantizationConfig
from model_compression_toolkit.core.pytorch.default_framework_info import PyTorchInfo
from model_compression_toolkit.core.common.framework_info import set_fw_info, get_fw_info
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
from model_compression_toolkit.core.common.collectors.statistics_collector import StatsCollector
from mct_quantizers import QuantizationMethod


class TestCalculateQuantizationParams:
    """Tests that calculate_quantization_params computes (or skips) activation
    quantization parameters for every ActivationQuantizationMode.

    QUANT and FLN_QUANT nodes must receive 'threshold'/'is_signed' params;
    NO_QUANT, FLN_NO_QUANT and PRESERVE_QUANT nodes must receive none.
    """

    def build_node(self, name='node', framework_attr=None, layer_class=torch.nn.Conv2d,
                   qcs: List[CandidateNodeQuantizationConfig] = None):
        """Build a BaseNode with fixed shapes and optional candidate quantization configs.

        Args:
            name: Node name.
            framework_attr: Framework attributes for the layer (defaults to an empty dict).
            layer_class: Layer type the node represents.
            qcs: Candidate quantization configs; the first entry becomes the base config.

        Returns:
            The constructed BaseNode.
        """
        # Use None + fallback instead of a mutable default argument, which would
        # be shared across calls.
        if framework_attr is None:
            framework_attr = {}
        node = BaseNode(name=name,
                        framework_attr=framework_attr,
                        input_shape=(4, 5, 6),
                        output_shape=(4, 5, 6),
                        weights={},
                        layer_class=layer_class,
                        reuse=False)
        if qcs:
            assert isinstance(qcs, list)
            node.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=qcs[0],
                                                           candidates_quantization_cfg=qcs)
        return node

    def build_qc(self, q_mode=ActivationQuantizationMode.QUANT):
        """Build a candidate quantization config with a POWER_OF_TWO 8-bit activation
        config set to the given quant mode, and a mocked weights config.

        Args:
            q_mode: Activation quantization mode to assign to the candidate.

        Returns:
            A CandidateNodeQuantizationConfig with the requested activation mode.
        """
        op_cfg = OpQuantizationConfig(
            default_weight_attr_config=AttributeQuantizationConfig(),
            attr_weights_configs_mapping={},
            activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
            activation_n_bits=8,
            enable_activation_quantization=True,
            quantization_preserving=False,
            supported_input_activation_n_bits=8,
            signedness=Signedness.AUTO
        )
        a_qcfg = NodeActivationQuantizationConfig(op_cfg=op_cfg)
        a_qcfg.set_qc(QuantizationConfig())
        a_qcfg.quant_mode = q_mode

        # Weights quantization is irrelevant for these tests, so mock it out.
        w_qcfg = Mock()
        w_qcfg._validate_consistent_weights_quant_mode.return_value = True
        w_qcfg.get_all_weight_attrs_configs.return_value = {}

        return CandidateNodeQuantizationConfig(activation_quantization_cfg=a_qcfg,
                                               weights_quantization_cfg=w_qcfg)

    def get_test_graph(self, node_name, q_mode, data):
        """Build a single-node graph whose output statistics collector holds a
        small fixed histogram, so threshold selection is deterministic.

        Args:
            node_name: Name of the single Conv2d node.
            q_mode: Activation quantization mode for the node's candidate config.
            data: Histogram bin edges to feed the node's statistics collector.

        Returns:
            Tuple of (graph, fw_impl) ready for calculate_quantization_params.
        """
        set_fw_info(PyTorchInfo)

        node = self.build_node(node_name,
                               framework_attr={'in_channels': 3, 'out_channels': 3, 'kernel_size': 3},
                               qcs=[self.build_qc(q_mode=q_mode)])

        graph = Graph('graph_name', input_nodes=[node],
                      nodes=[node],
                      output_nodes=[node],
                      edge_list=[]
                      )
        fw_impl = PytorchImplementation()

        graph.node_to_out_stats_collector = dict()
        for n in graph.nodes():
            n.prior_info = fw_impl.get_node_prior_info(node=n, graph=graph)

            # Seed the histogram collector directly with known bins/counts so
            # the computed power-of-two threshold is predictable.
            graph.node_to_out_stats_collector[n] = StatsCollector(
                init_min_value=0.0, init_max_value=1.0,
                out_channel_axis=get_fw_info().out_channel_axis_mapping.get(n.type))
            graph.node_to_out_stats_collector[n].hc._n_bins = 3
            graph.node_to_out_stats_collector[n].hc._bins = np.array(data)
            graph.node_to_out_stats_collector[n].hc._counts = np.array([1, 1])

        return graph, fw_impl

    ### test pattern for ActivationQuantizationMode
    # node_name, q_mode, input data, expected [threshold, is_signed]
    # (expected [None, None] means no activation params should be computed)
    test_input_0 = ['node_quant', ActivationQuantizationMode.QUANT, [0.4, 0.8, 1.2], [1.0, False]]
    test_input_1 = ['node_fln_quant', ActivationQuantizationMode.FLN_QUANT, [0.7, 1.4, 2.1], [2.0, False]]
    test_input_2 = ['node_fln_no_quant', ActivationQuantizationMode.FLN_NO_QUANT, [0.7, 1.4, 2.1], [None, None]]
    test_input_3 = ['node_no_quant', ActivationQuantizationMode.NO_QUANT, [0.7, 1.4, 2.1], [None, None]]
    test_input_4 = ['node_preserve_quant', ActivationQuantizationMode.PRESERVE_QUANT, [0.7, 1.4, 2.1], [None, None]]

    @pytest.mark.parametrize("inputs", [
        test_input_0,
        test_input_1,
        test_input_2,
        test_input_3,
        test_input_4
    ])
    def test_calculate_quantization_params(self, inputs):
        """Verify activation params are computed exactly for QUANT/FLN_QUANT modes."""
        # Unpack the parametrized case instead of indexing positionally.
        node_name, q_mode, data, expects = inputs
        expected_threshold, expected_is_signed = expects

        graph, fw_impl = self.get_test_graph(node_name, q_mode, data)

        calculate_quantization_params(graph, fw_impl, None)

        for node in graph.nodes:
            for candidate_qc in node.candidates_quantization_cfg:
                params = candidate_qc.activation_quantization_cfg.activation_quantization_params
                # isinstance (not type ==) is the idiomatic type check.
                assert isinstance(params, dict)
                if expected_threshold is not None:
                    ### QUANT or FLN_QUANT: params must be populated
                    assert 'threshold' in params
                    assert 'is_signed' in params
                    assert params['threshold'] == expected_threshold
                    assert params['is_signed'] == expected_is_signed
                else:
                    ### all other modes: no activation params computed
                    assert 'threshold' not in params
                    assert 'is_signed' not in params