Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
from model_compression_toolkit.logger import Logger

from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
AttributeQuantizationConfig, OpQuantizationConfig
Expand Down Expand Up @@ -60,12 +59,11 @@ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_val
kwargs: A dictionary with additional key arguments.

"""

if hasattr(self, config_parameter_name):
setattr(self, config_parameter_name, config_parameter_value)
else:
Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
f"was not updated!")
raise AttributeError(
f"Parameter {config_parameter_name} could not be found in the node quantization config.")

def __repr__(self) -> str:
"""
Expand Down Expand Up @@ -98,7 +96,7 @@ def __init__(self, op_cfg: OpQuantizationConfig):
self.signedness = op_cfg.signedness

self.activation_quantization_params = {}
# TODO irena: computed by compute_activation_bias_correction. shouldnt really be here
# TODO: computed by compute_activation_bias_correction. Probably shouldn't be here.
self.activation_bias_correction_term = None

@property
Expand Down Expand Up @@ -140,12 +138,14 @@ def __eq__(self, other: Any) -> bool:

return self.activation_quantization_method == other.activation_quantization_method and \
self.activation_n_bits == other.activation_n_bits and \
self.quant_mode == other.quant_mode
self.quant_mode == other.quant_mode and \
self.signedness == other.signedness

def __hash__(self):
return hash((self.activation_quantization_method,
self.activation_n_bits,
self.quant_mode))
self.quant_mode,
self.signedness))


class WeightsAttrQuantizationConfig:
Expand All @@ -166,16 +166,8 @@ def __init__(self,
self.weights_n_bits = weights_attr_cfg.weights_n_bits
self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold
self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization
self.weights_quantization_params = {}

# TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
self.weights_error_method = None
self.l_p_value = None

def set_qc(self, qc: QuantizationConfig):
# TODO irena: temporary keep the fields to not break everything at once.
self.weights_error_method = qc.weights_error_method
self.l_p_value = qc.l_p_value
self.weights_quantization_params = {}

def set_weights_quantization_param(self,
weights_params: dict):
Expand Down Expand Up @@ -207,18 +199,14 @@ def __eq__(self, other: Any) -> bool:
self.weights_quantization_method == other.weights_quantization_method and \
self.weights_n_bits == other.weights_n_bits and \
self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
self.enable_weights_quantization == other.enable_weights_quantization and \
self.weights_error_method == other.weights_error_method and \
self.l_p_value == other.l_p_value
self.enable_weights_quantization == other.enable_weights_quantization

def __hash__(self):
return hash((self.weights_channels_axis,
self.weights_error_method,
self.weights_quantization_method,
self.weights_n_bits,
self.weights_per_channel_threshold,
self.enable_weights_quantization,
self.l_p_value))
self.enable_weights_quantization))


class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
Expand Down Expand Up @@ -285,16 +273,14 @@ def __init__(self,

self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
weights_channels_axis=weights_channels_axis)
# TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
self.min_threshold = None
# TODO: this is set by the batch-norm reconstruction substitution when folded batch norms are added back, to mark
# the nodes that the correction should be applied to (for some nodes it gets disabled) and the BNs are removed.
# The actual correction is only computed when it's applied in PTQ, so it seems that both substitutions could
# be unified, and no info would need to be passed between them.
self.weights_second_moment_correction = None
self.weights_bias_correction = None

def set_qc(self, qc: QuantizationConfig):
# TODO irena: temporary keep the fields to not break everything at once.
self.min_threshold = qc.min_threshold
self.weights_second_moment_correction = qc.weights_second_moment_correction
self.weights_bias_correction = qc.weights_bias_correction
# TODO: the corrected bias is injected into the node config; it probably shouldn't live here. It could also be
# computed only on the final config, instead of on all candidates, in which case it wouldn't need to be stored at all.
self.bias_corrected = None

def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
"""
Expand Down Expand Up @@ -431,8 +417,8 @@ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_val
if hasattr(attr_cfg, config_parameter_name):
setattr(attr_cfg, config_parameter_name, config_parameter_value)
else:
Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
f"weights attribute {attr_name} and was not updated!")
raise AttributeError(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
f"weights attribute {attr_name}.")
else: # pragma: no cover
Logger.critical(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")

Expand All @@ -449,10 +435,7 @@ def __eq__(self, other: Any) -> bool:
if not isinstance(other, NodeWeightsQuantizationConfig):
return False # pragma: no cover

return self.min_threshold == other.min_threshold and \
self.simd_size == other.simd_size and \
self.weights_second_moment_correction == other.weights_second_moment_correction and \
self.weights_bias_correction == other.weights_bias_correction and \
return self.simd_size == other.simd_size and \
self.attributes_config_mapping.keys() == other.attributes_config_mapping.keys() and \
all([self.attributes_config_mapping[k] == other.attributes_config_mapping[k]
for k in self.attributes_config_mapping.keys()]) and \
Expand All @@ -461,9 +444,6 @@ def __eq__(self, other: Any) -> bool:
for k in self.pos_attributes_config_mapping.keys()])

def __hash__(self):
return hash((self.min_threshold,
self.simd_size,
self.weights_second_moment_correction,
self.weights_bias_correction,
return hash((self.simd_size,
frozenset(self.attributes_config_mapping),
frozenset(self.pos_attributes_config_mapping)))
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class QuantizationConfig:
shift_negative_activation_correction: bool = True
activation_channel_equalization: bool = False
z_threshold: float = math.inf
min_threshold: float = MIN_THRESHOLD
l_p_value: int = 2
linear_collapsing: bool = True
residual_collapsing: bool = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mct_quantizers import QuantizationMethod

import model_compression_toolkit.core.common.quantization.quantization_params_generation as qpg
from model_compression_toolkit.constants import MIN_THRESHOLD
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
Expand Down Expand Up @@ -64,7 +65,7 @@ def compute_activation_qparams(quant_cfg: QuantizationConfig,
node_activation_quant_cfg.activation_n_bits,
min_value,
max_value,
min_threshold=quant_cfg.min_threshold,
min_threshold=MIN_THRESHOLD,
quant_error_method=quant_cfg.activation_error_method,
is_signed=signed
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,6 @@
from model_compression_toolkit.logger import Logger


def _collect_nodes_for_hmse(nodes_list: List[BaseNode], graph: Graph) -> List[BaseNode]:
"""
Collects nodes that are compatiable for parameters selection search using HMSE,
that is, have a kernel attribute that is configured for HMSE error method.

Args:
nodes_list: A list of nodes to search quantization parameters for.
graph: Graph to compute its nodes' quantization parameters..

Returns: A (possibly empty) list of nodes.

"""
hmse_nodes = []
for n in nodes_list:
if n.kernel_attr is not None and n.is_weights_quantization_enabled(n.kernel_attr) and \
all([c.weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_error_method ==
QuantizationErrorMethod.HMSE for c in n.candidates_quantization_cfg]):
hmse_nodes.append(n)

return hmse_nodes


def calculate_quantization_params(graph: Graph,
quant_cfg: QuantizationConfig,
fw_impl: FrameworkImplementation,
Expand Down Expand Up @@ -87,44 +65,41 @@ def calculate_quantization_params(graph: Graph,
# Collecting nodes that are configured to search weights quantization parameters using HMSE optimization
# and computing required Hessian information to be used for HMSE parameters selection.
# The Hessian scores are computed and stored in the hessian_info_service object.
nodes_for_hmse = _collect_nodes_for_hmse(nodes_list, graph)
if len(nodes_for_hmse) > 0:
dataloader = fw_impl.convert_data_gen_to_dataloader(repr_data_gen_fn, batch_size=1)
request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
granularity=HessianScoresGranularity.PER_ELEMENT,
data_loader=dataloader,
n_samples=num_hessian_samples,
target_nodes=nodes_for_hmse)
hessian_info_service.fetch_hessian(request)
if quant_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
nodes_for_hmse = [n for n in nodes_list if n.kernel_attr and n.is_weights_quantization_enabled(n.kernel_attr)]
if nodes_for_hmse:
dataloader = fw_impl.convert_data_gen_to_dataloader(repr_data_gen_fn, batch_size=1)
request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
granularity=HessianScoresGranularity.PER_ELEMENT,
data_loader=dataloader,
n_samples=num_hessian_samples,
target_nodes=nodes_for_hmse)
hessian_info_service.fetch_hessian(request)

for n in tqdm(nodes_list, "Calculating quantization parameters"): # iterate only nodes that we should compute their thresholds
for candidate_qc in n.candidates_quantization_cfg:
for attr in n.get_node_weights_attributes():
if n.is_weights_quantization_enabled(attr):
# If the node's weights attribute should be quantized, we compute its quantization parameters
attr_cfg = candidate_qc.weights_quantization_cfg.get_attr_config(attr)
channels_axis = attr_cfg.weights_channels_axis
if channels_axis is not None:
output_channels_axis = channels_axis[0]
else:
output_channels_axis = None

mod_attr_cfg = attr_cfg
output_channels_axis = attr_cfg.weights_channels_axis.output

if attr_cfg.weights_error_method == QuantizationErrorMethod.HMSE:
weights_error_method = quant_cfg.weights_error_method
if weights_error_method == QuantizationErrorMethod.HMSE:
# Although we collected nodes for HMSE before running the loop, we keep this verification to
# notify the user in case of HMSE configured for node that is not compatible for this method
if n.kernel_attr is None or n.kernel_attr not in attr:
Logger.warning(f"The HMSE error method for parameters selection is only supported for "
f"kernel weights attributes. Running parameters selection for attribute "
f"'{attr}' in node '{n.name}' with the default MSE error method instead.")
mod_attr_cfg = copy.deepcopy(attr_cfg)
mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE
weights_error_method = QuantizationErrorMethod.MSE

min_threshold = candidate_qc.weights_quantization_cfg.min_threshold
weights_params, output_channels_axis = compute_weights_qparams(n.get_weights_by_keys(attr),
mod_attr_cfg, output_channels_axis,
min_threshold=min_threshold, node=n,
attr_cfg,
weights_error_method,
quant_cfg.l_p_value,
output_channels_axis,
node=n,
hessian_info_service=hessian_info_service,
num_hessian_samples=num_hessian_samples)
attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import numpy as np
from mct_quantizers import QuantizationMethod

from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES, MIN_THRESHOLD
from model_compression_toolkit.core import QuantizationErrorMethod
from model_compression_toolkit.core.common.hessian import HessianInfoService
from model_compression_toolkit.core.common.quantization.quantization_params_generation import \
power_of_two_selection_tensor, lut_kmeans_tensor, symmetric_selection_tensor, uniform_selection_tensor
Expand All @@ -28,10 +29,12 @@
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig


def compute_weights_qparams(weights_attr_values: np.ndarray,
def compute_weights_qparams(weights_attr_data: np.ndarray,
attr_quant_config: 'WeightsAttrQuantizationConfig',
weights_error_method: QuantizationErrorMethod,
l_p_value: int,
output_channels_axis: int,
min_threshold: float,
min_threshold: float = MIN_THRESHOLD,
node=None,
hessian_info_service: HessianInfoService = None,
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
Expand All @@ -40,8 +43,10 @@ def compute_weights_qparams(weights_attr_values: np.ndarray,
instance.

Args:
weights_attr_values: Weights attribute parameter to compute the quantization thresholds for.
weights_attr_data: Weights attribute parameter to compute the quantization thresholds for.
attr_quant_config: A specific weights attribute quantization configuration to get its params.
weights_error_method: quantization error method.
l_p_value: p-norm to use for the Lp-norm distance.
output_channels_axis: Index of the kernel output channels dimension.
min_threshold: Minimal threshold to use if threshold is too small.
node: The node for which the quantization error is computed (used only with HMSE error method).
Expand All @@ -54,13 +59,13 @@ def compute_weights_qparams(weights_attr_values: np.ndarray,
"""
params_fn = _get_weights_quantization_params_fn(attr_quant_config.weights_quantization_method)
weights_params, output_channels_axis = params_fn(
weights_attr_values,
p=attr_quant_config.l_p_value,
weights_attr_data,
p=l_p_value,
n_bits=attr_quant_config.weights_n_bits,
per_channel=attr_quant_config.weights_per_channel_threshold,
channel_axis=output_channels_axis,
min_threshold=min_threshold,
quant_error_method=attr_quant_config.weights_error_method,
quant_error_method=weights_error_method,
node=node,
hessian_info_service=hessian_info_service,
num_hessian_samples=num_hessian_samples)
Expand Down
Loading