model_compression_toolkit/core/common/quantization/node_quantization_config.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from __future__ import annotations
-
 from typing import Any, List, Dict, TYPE_CHECKING
 from enum import Enum, auto
 
@@ -26,7 +24,6 @@
     AttributeQuantizationConfig, OpQuantizationConfig
 
 if TYPE_CHECKING:
-    from model_compression_toolkit.core.common import BaseNode
     from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
 
 ##########################################
@@ -143,17 +140,15 @@ def fln_quantization(self):
         return self.quant_mode == ActivationQuantizationMode.FLN_QUANT
 
     def set_activation_quantization_param(self,
-                                          activation_params: dict,
-                                          node: BaseNode):
+                                          activation_params: dict):
         """
         Set a quantization parameter for the node's activation.
 
         Args:
             activation_params: Dictionary that contains activation quantization params.
-            node: node in a graph that represents the model.
 
         """
-        assert node.is_activation_quantization_enabled() or node.is_fln_quantization()
+        assert self.quant_mode == ActivationQuantizationMode.QUANT or self.quant_mode == ActivationQuantizationMode.FLN_QUANT
         for param_name, param_value in activation_params.items():
             self.activation_quantization_params[param_name] = param_value
 
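The setter is now self-contained: instead of asserting properties of the owning node, it checks its own quant_mode before storing parameters. A minimal sketch of the new contract, using simplified stand-in classes rather than the actual MCT types:

from enum import Enum, auto

class Mode(Enum):  # stand-in for ActivationQuantizationMode
    QUANT = auto()
    FLN_QUANT = auto()
    NO_QUANT = auto()

class ActivationCfg:  # stand-in, not the MCT config class
    def __init__(self, quant_mode):
        self.quant_mode = quant_mode
        self.activation_quantization_params = {}

    def set_activation_quantization_param(self, activation_params: dict):
        # the guard now lives on the config itself; no node argument is needed
        assert self.quant_mode in (Mode.QUANT, Mode.FLN_QUANT)
        for name, value in activation_params.items():
            self.activation_quantization_params[name] = value

cfg = ActivationCfg(Mode.QUANT)
cfg.set_activation_quantization_param({'threshold': 1.0, 'is_signed': False})
assert cfg.activation_quantization_params['threshold'] == 1.0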
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py
@@ -134,4 +134,4 @@ def calculate_quantization_params(graph: Graph,
                                                         activation_quant_cfg=candidate_qc.activation_quantization_cfg, node_prior_info=n.prior_info,
                                                         out_stats_container=graph.get_out_stats_collector(n))
             # Create a NodeQuantizationConfig containing all quantization params and attach it to the node
-            candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params, n)
+            candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params)
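After this call, the candidate's activation_quantization_params dict holds whatever the qparams computation produced. Per the updated test below, a power-of-two activation config ends up with at least the 'threshold' and 'is_signed' keys, e.g. {'threshold': 1.0, 'is_signed': False} for the plain QUANT case.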
@@ -15,218 +15,124 @@
 import pytest
 import numpy as np
 import torch
 
+from unittest.mock import Mock
+from typing import List
+from model_compression_toolkit.core.common import Graph, BaseNode
+
 from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \
     calculate_quantization_params
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig, NodeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import \
-    ActivationQuantizationMode, NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig
+    ActivationQuantizationMode, NodeActivationQuantizationConfig
 from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig
-from model_compression_toolkit.core import QuantizationConfig, QuantizationErrorMethod
-from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
-from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
-    AttachTpcToPytorch
-import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
+from model_compression_toolkit.core import QuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness, \
     AttributeQuantizationConfig
 from model_compression_toolkit.core.pytorch.default_framework_info import PyTorchInfo
 from model_compression_toolkit.core.common.framework_info import set_fw_info, get_fw_info
 
 from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
 from model_compression_toolkit.core.common.collectors.statistics_collector import StatsCollector
-from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, WEIGHTS_N_BITS
 from mct_quantizers import QuantizationMethod
 
-from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
 
 class TestCalculateQuantizationParams:
-    def get_op_qco(self):
-        # define a default quantization config for all non-specified weights attributes.
-        default_weight_attr_config = AttributeQuantizationConfig()
-
-        # define a quantization config to quantize the kernel (for layers where there is a kernel attribute).
-        kernel_base_config = AttributeQuantizationConfig(
-            weights_n_bits=8,
-            weights_per_channel_threshold=True,
-            enable_weights_quantization=True)
-
-        base_cfg = schema.OpQuantizationConfig(
-            default_weight_attr_config=default_weight_attr_config,
-            attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config},
-            activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
-            activation_n_bits=8,
-            supported_input_activation_n_bits=8,
-            enable_activation_quantization=True,
-            quantization_preserving=False,
-            signedness=Signedness.AUTO)
-
-        default_config = schema.OpQuantizationConfig(
-            default_weight_attr_config=default_weight_attr_config,
+    def build_node(self, name='node', framework_attr={}, layer_class=torch.nn.Conv2d,
+                   qcs: List[CandidateNodeQuantizationConfig] = None):
+        node = BaseNode(name=name,
+                        framework_attr=framework_attr,
+                        input_shape=(4, 5, 6),
+                        output_shape=(4, 5, 6),
+                        weights={},
+                        layer_class=layer_class,
+                        reuse=False)
+        if qcs:
+            assert isinstance(qcs, list)
+            node.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=qcs[0],
+                                                           candidates_quantization_cfg=qcs)
+        return node
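Note that the first candidate doubles as the node's base config (base_quantization_cfg=qcs[0]). Used from inside a test method, with illustrative values mirroring the get_test_graph call further down:

node = self.build_node('conv', framework_attr={'in_channels': 3, 'out_channels': 3, 'kernel_size': 3},
                       qcs=[self.build_qc(q_mode=ActivationQuantizationMode.QUANT)])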

+    def build_qc(self, q_mode=ActivationQuantizationMode.QUANT):
+        op_cfg = OpQuantizationConfig(
+            default_weight_attr_config=AttributeQuantizationConfig(),
+            attr_weights_configs_mapping={},
             activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
             activation_n_bits=8,
             supported_input_activation_n_bits=8,
             enable_activation_quantization=True,
             quantization_preserving=False,
             signedness=Signedness.AUTO
         )
 
-        mx_cfg_list = [base_cfg]
-        for n in [8, 4, 2]:
-            mx_cfg_list.append(base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: n}}))
-
-        return base_cfg, mx_cfg_list, default_config
-
-    def generate_tpc_local(self, default_config, base_config, mixed_precision_cfg_list):
-        default_configuration_options = schema.QuantizationConfigOptions(
-            quantization_configurations=tuple([default_config]))
-        mixed_precision_configuration_options = schema.QuantizationConfigOptions(
-            quantization_configurations=tuple(mixed_precision_cfg_list),
-            base_config=base_config)
-
-        operator_set = []
-
-        conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV, qc_options=mixed_precision_configuration_options)
-        relu = schema.OperatorsSet(name=schema.OperatorSetNames.RELU)
-        operator_set.extend([conv, relu])
-
-        generated_tpc = schema.TargetPlatformCapabilities(
-            default_qco=default_configuration_options,
-            operator_set=tuple(operator_set))
-
-        return generated_tpc
-
-    def get_tpc(self):
-        base_cfg, mx_cfg_list, default_config = self.get_op_qco()
-        tpc = self.generate_tpc_local(default_config, base_cfg, mx_cfg_list)
-        return tpc
-
-    def get_float_model(self):
-        class BaseModel(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
-                self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
-                self.conv3 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x):
-                x = self.conv1(x)
-                x = self.conv2(x)
-                x = self.relu(x)
-                x = self.conv3(x)
-                return x
-
-        return BaseModel()
-
-    def _create_weights_attr_quantization_config(self, weights_n_bits: int) -> AttributeQuantizationConfig:
-        weights_attr_config = AttributeQuantizationConfig(weights_n_bits=weights_n_bits)
-        return weights_attr_config
-
-    def _create_node_weights_op_cfg(self,
-                                    def_weight_attr_config: AttributeQuantizationConfig) -> OpQuantizationConfig:
-        # define a quantization config to quantize the kernel (for layers where there is a kernel attribute).
-        kernel_base_config = AttributeQuantizationConfig(
-            weights_quantization_method=QuantizationMethod.SYMMETRIC,
-            enable_weights_quantization=False,
-            weights_n_bits=8)
-
-        # define a quantization config to quantize the bias (for layers where there is a bias attribute).
-        bias_config = AttributeQuantizationConfig()
-
-        attr_weights_configs_mapping = {'weight': kernel_base_config, 'bias': bias_config}
-        op_cfg = OpQuantizationConfig(
-            default_weight_attr_config=def_weight_attr_config,
-            attr_weights_configs_mapping=attr_weights_configs_mapping,
-            activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
-            activation_n_bits=8,
-            supported_input_activation_n_bits=8,
-            enable_activation_quantization=True,
-            quantization_preserving=False,
-            signedness=Signedness.AUTO
-        )
+        a_qcfg = NodeActivationQuantizationConfig(op_cfg=op_cfg)
+        a_qcfg.set_qc(QuantizationConfig())
+        a_qcfg.quant_mode = q_mode
 
-        return op_cfg
+        w_qcfg = Mock()
+        w_qcfg._validate_consistent_weights_quant_mode.return_value = True
+        w_qcfg.get_all_weight_attrs_configs.return_value = {}
 
-    def get_test_graph(self, qem: QuantizationErrorMethod):
-        float_model = self.get_float_model()
-        set_fw_info(PyTorchInfo)
+        qc = CandidateNodeQuantizationConfig(activation_quantization_cfg=a_qcfg,
+                                             weights_quantization_cfg=w_qcfg)
+        return qc
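The weights config here is a Mock with just two members stubbed: _validate_consistent_weights_quant_mode returning True and get_all_weight_attrs_configs returning an empty dict. Presumably these are the only weights-side calls the qparams pass makes, so stubbing them keeps the test focused on activation parameters without building a real NodeWeightsQuantizationConfig.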

-        fw_impl = PytorchImplementation()
-        graph = fw_impl.model_reader(float_model,
-                                     self.representative_data_gen)
-
-        quantization_config = QuantizationConfig(weights_error_method=qem)
+    def get_test_graph(self, node_name, q_mode, data):
+        set_fw_info(PyTorchInfo)
 
-        tpc = self.get_tpc()
-        attach2pytorch = AttachTpcToPytorch()
-        fqc = attach2pytorch.attach(
-            tpc, quantization_config.custom_tpc_opset_to_layer)
-        graph.set_fqc(fqc)
+        node = self.build_node(node_name, framework_attr={'in_channels': 3, 'out_channels': 3, 'kernel_size': 3},
+                               qcs=[self.build_qc(q_mode=q_mode)])
 
-        def_weight_attr_config = self._create_weights_attr_quantization_config(weights_n_bits=8)
-        op_cfg = self._create_node_weights_op_cfg(def_weight_attr_config=def_weight_attr_config)
+        graph = Graph('graph_name', input_nodes=[node],
+                      nodes=[node],
+                      output_nodes=[node],
+                      edge_list=[]
+                      )
+        fw_impl = PytorchImplementation()
 
         graph.node_to_out_stats_collector = dict()
-        for id, n in enumerate(graph.nodes):
+        for n in graph.nodes():
             n.prior_info = fw_impl.get_node_prior_info(node=n, graph=graph)
 
-            activation_quantization_cfg = NodeActivationQuantizationConfig(op_cfg=op_cfg)
-            activation_quantization_cfg.set_qc(quantization_config)
-            weights_quantization_cfg = NodeWeightsQuantizationConfig(op_cfg=op_cfg,
-                                                                     weights_channels_axis=ChannelAxisMapping(0, 1),
-                                                                     node_attrs_list=['weight', 'bias'])
-            candidate_qc_a = CandidateNodeQuantizationConfig(
-                activation_quantization_cfg=activation_quantization_cfg,
-                weights_quantization_cfg=weights_quantization_cfg)
-            if n.name in ['conv3']:
-                candidate_qc_a.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
-            else:
-                candidate_qc_a.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.QUANT
-            n.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=candidate_qc_a, candidates_quantization_cfg=[candidate_qc_a, candidate_qc_a])
 
             graph.node_to_out_stats_collector[n] = StatsCollector(init_min_value=0.0, init_max_value=1.0, out_channel_axis=get_fw_info().out_channel_axis_mapping.get(n.type))
             graph.node_to_out_stats_collector[n].hc._n_bins = 3
-            if n.name in ['conv1']:
-                graph.node_to_out_stats_collector[n].hc._bins = np.array([0.4, 0.8, 1.2])
-            elif n.name in ['conv2']:
-                graph.node_to_out_stats_collector[n].hc._bins = np.array([0.7, 1.4, 2.1])
-            elif n.name in ['conv3']:
-                graph.node_to_out_stats_collector[n].hc._bins = np.array([-32, -24, -1])
-            elif n.name in ['relu']:
-                graph.node_to_out_stats_collector[n].hc._bins = np.array([2.0, 4.0, 6.0])
-            else:
-                graph.node_to_out_stats_collector[n].hc._bins = np.array([0.1, 0.2, 0.3])
+            graph.node_to_out_stats_collector[n].hc._bins = np.array(data)
             graph.node_to_out_stats_collector[n].hc._counts = np.array([1, 1])
 
         return graph, fw_impl
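Seeding the histogram collector directly through the private hc._bins / hc._counts fields removes the need for calibration data, which is presumably why calculate_quantization_params is called below with None as the representative data generator. For instance, data=[0.4, 0.8, 1.2] with counts [1, 1] describes two populated bins topping out at 1.2, and the expected power-of-two threshold recorded for that case in test_input_0 is 1.0.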

-    def representative_data_gen(self, shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=10):
-        for _ in range(num_iter):
-            yield [torch.randn(batch_size, *shape)] * num_inputs
-
-    def test_calculate_quantization_params(self):
-        graph, fw_impl = self.get_test_graph(QuantizationErrorMethod.MSE)
-
-        calculate_quantization_params(graph, fw_impl, self.representative_data_gen)
+    ### test pattern for ActivationQuantizationMode
+    # node_name, q_mode, input data, expected value
+    test_input_0 = ['node_quant', ActivationQuantizationMode.QUANT, [0.4, 0.8, 1.2], [1.0, False]]
+    test_input_1 = ['node_fln_quant', ActivationQuantizationMode.FLN_QUANT, [0.7, 1.4, 2.1], [2.0, False]]
+    test_input_2 = ['node_fln_no_quant', ActivationQuantizationMode.FLN_NO_QUANT, [0.7, 1.4, 2.1], [None, None]]
+    test_input_3 = ['node_no_quant', ActivationQuantizationMode.NO_QUANT, [0.7, 1.4, 2.1], [None, None]]
+    test_input_4 = ['node_preserve_quant', ActivationQuantizationMode.PRESERVE_QUANT, [0.7, 1.4, 2.1], [None, None]]
+    @pytest.mark.parametrize("inputs", [
+        test_input_0,
+        test_input_1,
+        test_input_2,
+        test_input_3,
+        test_input_4
+    ])
+    def test_calculate_quantization_params(self, inputs):
+        expects = inputs[3]
 
+        graph, fw_impl = self.get_test_graph(inputs[0], inputs[1], inputs[2])
 
+        calculate_quantization_params(graph, fw_impl, None)

         for node in graph.nodes:
             for candidate_qc in node.candidates_quantization_cfg:
                 assert type(candidate_qc.activation_quantization_cfg.activation_quantization_params) == dict
-                assert 'threshold' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys()
-                assert 'is_signed' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys()
-
-                threshold = candidate_qc.activation_quantization_cfg.activation_quantization_params['threshold']
-                is_signed = candidate_qc.activation_quantization_cfg.activation_quantization_params['is_signed']
-                if node.name in 'conv1':
-                    assert threshold == 1.0
-                    assert is_signed == False
-                elif node.name in 'conv2':
-                    assert threshold == 2.0
-                    assert is_signed == False
-                elif node.name in 'conv3':
-                    assert threshold == 64.0
-                    assert is_signed == True
-                elif node.name in 'relu':
-                    assert threshold == 16.0
-                    assert is_signed == False
+                if expects[0] is not None:
+                    ### QUANT or FLN_QUANT
+                    assert 'threshold' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys()
+                    assert 'is_signed' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys()
+
+                    threshold = candidate_qc.activation_quantization_cfg.activation_quantization_params['threshold']
+                    is_signed = candidate_qc.activation_quantization_cfg.activation_quantization_params['is_signed']
+                    assert threshold == expects[0]
+                    assert is_signed == expects[1]
+                else:
+                    assert 'threshold' not in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys()
+                    assert 'is_signed' not in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys()
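The parametrization turns the old single scenario into a five-case matrix, one per ActivationQuantizationMode. Assuming a standard pytest layout (the file path below is hypothetical), the matrix can be run selectively with:

# hypothetical path; adjust to where this test lives in the repo
pytest path/to/test_qparams_computation.py -k test_calculate_quantization_params -v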