Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions model_compression_toolkit/core/common/framework_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,10 @@
# limitations under the License.
# ==============================================================================


from collections.abc import Callable
from enum import Enum
from typing import Dict, Any, Tuple, NamedTuple, Optional
from abc import ABC, abstractmethod

from mct_quantizers import QuantizationMethod


class ChannelAxis(Enum):
"""
Expand Down Expand Up @@ -56,13 +52,11 @@ class FrameworkInfo(ABC):
kernel_ops_attribute_mapping (Dict): Dictionary from a framework operator to its weight attribute to quantize.
out_channel_axis_mapping (Dict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
_layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
activation_quantizer_factory_mapping: A mapping from QuantizationMethod to a factory function that accepts activation bitwidth and a dict of quantization params, and returns the corresponding quantization function.
"""

kernel_ops_attribute_mapping: Dict[Any, str]
kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
out_channel_axis_mapping: Dict[Any, int]
activation_quantizer_factory_mapping: Dict[QuantizationMethod, Callable[[int, dict], Callable]]

_layer_min_max_mapping: Dict[Any, tuple]
_default_channel_mapping = ChannelAxisMapping(None, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,24 @@ def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig],
return quantized_weights


def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]) -> List:
def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig],
                               get_activation_quantization_fn_factory: Callable) -> List:
    """
    Builds a list of quantizers for each of the bitwidth candidates for activation quantization,
    to be stored and used during MP search.

    Args:
        node_q_cfg: Quantization configuration candidates of the node that generated the layer that will
            use this quantizer.
        get_activation_quantization_fn_factory: activation quantization functions factory.

    Returns: a list of activation quantizers - for each bitwidth and layer's attribute to be quantized.
    """
    # Iterate the candidates directly instead of enumerate()-ing and
    # re-indexing node_q_cfg[index] - same order, same result.
    return [get_activation_quantization_fn(qc.activation_quantization_cfg,
                                           get_activation_quantization_fn_factory)
            for qc in node_q_cfg]
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,21 @@

from mct_quantizers import QuantizationMethod

from model_compression_toolkit.core.common.framework_info import get_fw_info
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
symmetric_quantizer, uniform_quantizer


def get_activation_quantization_fn_factory(quantization_method: QuantizationMethod) -> Callable[[int, dict], Callable]:
"""
Get factory for activation quantizer.

Args:
quantization_method: quantization method for activation.

Returns:
Factory that accepts activation bitwidth and a dict of quantization params, and returns the quantizer.
"""
return get_fw_info().activation_quantizer_factory_mapping[quantization_method]


def get_activation_quantization_fn(activation_quantization_cfg: NodeActivationQuantizationConfig) -> Callable:
def get_activation_quantization_fn(activation_quantization_cfg: NodeActivationQuantizationConfig,
get_activation_quantization_fn_factory: Callable) -> Callable:
"""
Get activation quantizer based on activation quantization configuration.

Args:
activation_quantization_cfg: activation quantization configuration.
get_activation_quantization_fn_factory: activation quantization functions factory.

Returns:
Activation quantizer that accepts a tensor and returns a quantized tensor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def compute_activation_bias_correction(graph: Graph,
fw_impl: FrameworkImplementation,
linear_node: BaseNode,
prev_node: BaseNode,
kernel_size: str) -> Graph:
kernel_size: str,
get_activation_quantization_fn_factory: Callable) -> Graph:
"""
Compute the activation bias correction term, and store it in the final activation
quantization configuration.
Expand All @@ -79,6 +80,7 @@ def compute_activation_bias_correction(graph: Graph,
linear_node: Node to compute the activation bias correction for.
prev_node: Node to compute the activation error caused by its activation quantization.
kernel_size: The framework specific attribute name of the convolution layer's kernel size.
get_activation_quantization_fn_factory: activation quantization functions factory.

Returns:
Graph with activation bias correction term for each node.
Expand All @@ -105,7 +107,8 @@ def compute_activation_bias_correction(graph: Graph,
float_centers = calculate_bin_centers(float_bins)

# Quantize the bin edges and calculate the centers of the quantized bins
activation_quantizer = get_activation_quantization_fn(prev_node_act_quant_cfg)
activation_quantizer = get_activation_quantization_fn(prev_node_act_quant_cfg,
get_activation_quantization_fn_factory)
quant_bins = activation_quantizer(fw_impl.to_tensor(float_bins))
quant_bins = fw_impl.to_numpy(quant_bins)
quant_centers = calculate_bin_centers(quant_bins)
Expand Down Expand Up @@ -150,7 +153,8 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
quant_config: QuantizationConfig,
fw_impl: FrameworkImplementation,
activation_bias_correction_node_matchers: Callable,
kernel_size: str) -> Graph:
kernel_size: str,
get_activation_quantization_fn_factory: Callable) -> Graph:
"""
Compute the activation bias correction term for the graph.

Expand All @@ -160,7 +164,7 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
activation_bias_correction_node_matchers: Function to match the layers for activation bias correction.
kernel_size: The framework specific attribute name of the convolution layer's kernel size.

get_activation_quantization_fn_factory: activation quantization functions factory.

Returns:
Graph with activation bias correction term for each relevant node.
Expand All @@ -176,5 +180,6 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
fw_impl=fw_impl,
linear_node=n,
prev_node=prev_node,
kernel_size=kernel_size)
kernel_size=kernel_size,
get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
return graph
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
ActivationQuantizationMode
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
get_activation_quantization_fn_factory
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
Expand Down Expand Up @@ -253,6 +251,7 @@ def shift_negative_function(graph: Graph,
padding_str: str,
bias_str: str,
bias_flag_str: str,
get_activation_quantization_fn_factory: Callable,
zero_padding_node: BaseNode = None,
bypass_nodes: List = None,
params_search_quantization_fn: Callable = None
Expand All @@ -278,6 +277,7 @@ def shift_negative_function(graph: Graph,
padding_str: The framework specific attribute name of the padding.
bias_str: The framework specific attribute name of the bias.
bias_flag_str: The framework specific attribute name of the bias flag.
get_activation_quantization_fn_factory: activation quantization functions factory.
zero_padding_node: ZeroPadding2D node that may be in the graph before the linear layer.
params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.

Expand Down Expand Up @@ -335,7 +335,7 @@ def shift_negative_function(graph: Graph,
the histogram (which is a numpy object) is quantized using the non-linear node activation
quantization function (to estimate the expected mse comparing to the original histogram).
The quantization function is a framework function, which makes it fail since it
expects a fw tensor. The commmon part of SNC receives an argument which is a callable
expects a fw tensor. The common part of SNC receives an argument which is a callable
that receives two arguments and returns one: it gets the fw activation quantization function
and the bins to quantize. The function (of each fw) responsible for doing (if needed) a preprocessing and postprocessing
to the bins which is a numpy object.
Expand Down Expand Up @@ -569,6 +569,7 @@ def apply_shift_negative_correction(graph: Graph,
padding_str: str,
bias_str: str,
bias_flag_str: str,
get_activation_quantization_fn_factory: Callable,
params_search_quantization_fn: Callable=None) -> Graph:
"""
Apply the substitution even if the linear node is not immediately after
Expand All @@ -590,6 +591,9 @@ def apply_shift_negative_correction(graph: Graph,
padding_str: The framework specific attribute name of the padding.
bias_str: The framework specific attribute name of the bias.
bias_flag_str: The framework specific attribute name of the bias flag.
get_activation_quantization_fn_factory: activation quantization functions factory.
params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.

Returns:
Graph after applying shift negative on selected activations.
"""
Expand Down Expand Up @@ -620,6 +624,7 @@ def apply_shift_negative_correction(graph: Graph,
padding_str,
bias_str,
bias_flag_str,
get_activation_quantization_fn_factory,
zero_padding_node=pad_node,
bypass_nodes=bypass_nodes,
params_search_quantization_fn=params_search_quantization_fn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# ==============================================================================
from typing import List

from model_compression_toolkit.core import FrameworkInfo
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
from tensorflow.python.util.object_identity import Reference as TFReference

Expand Down Expand Up @@ -57,5 +57,6 @@ def _quantize_node_activations(self,
Output of the node.

"""
activation_quantizer = get_activation_quantization_fn(node.final_activation_quantization_cfg)
activation_quantizer = get_activation_quantization_fn(node.final_activation_quantization_cfg,
get_activation_quantization_fn_factory)
return activation_quantizer(input_tensors)
13 changes: 0 additions & 13 deletions model_compression_toolkit/core/keras/default_framework_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,16 @@
from typing import Tuple, Any, Dict
from functools import wraps

from model_compression_toolkit.core.keras.quantizer.lut_fake_quant import activation_lut_kmean_quantizer
from packaging import version

if version.parse(tf.__version__) >= version.parse("2.13"):
from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU, Activation
else:
from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU, Activation # pragma: no cover
from model_compression_toolkit.core.common.framework_info import FrameworkInfo, set_fw_info, ChannelAxisMapping
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.constants import SOFTMAX_THRESHOLD, ACTIVATION
from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
KERNEL, DEPTHWISE_KERNEL, GELU
from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization


class KerasInfo(FrameworkInfo):
Expand Down Expand Up @@ -69,16 +66,6 @@ class KerasInfo(FrameworkInfo):
Dense: -1,
Conv2DTranspose: -1}

"""
Mapping from a QuantizationMethod to an activation quantizer function.
"""
activation_quantizer_factory_mapping = {
QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
QuantizationMethod.SYMMETRIC: symmetric_quantization,
QuantizationMethod.UNIFORM: uniform_quantization,
QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer
}

"""
Map from an activation function name to its min/max output values (if known).
The values are used for tensor min/max values initialization.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
NodeFrameworkAttrMatcher
from model_compression_toolkit.core.common.substitutions.shift_negative_activation import \
apply_shift_negative_correction
from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
from model_compression_toolkit.core.keras.constants import KERNEL_SIZE, STRIDES, ACTIVATION, SWISH, \
SELU, GELU, FUNCTION, ADD, PAD
from model_compression_toolkit.core.keras.constants import NEGATIVE_SLOPE, PADDING, PAD_SAME, PAD_VALID, BIAS, USE_BIAS
Expand Down Expand Up @@ -252,5 +253,6 @@ def keras_apply_shift_negative_correction(graph: Graph,
is_padding_node_and_node_has_padding,
PADDING,
BIAS,
USE_BIAS
USE_BIAS,
get_activation_quantization_fn_factory
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
verify_candidates_descending_order, init_activation_quantizers
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
CandidateNodeQuantizationConfig
from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
from model_compression_toolkit.logger import Logger

import tensorflow as tf
Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(self,
if qc.activation_quantization_cfg.quant_mode != node_q_cfg[0].activation_quantization_cfg.quant_mode:
Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).") # pragma: no cover

self.activation_quantizers = init_activation_quantizers(self.node_q_cfg)
self.activation_quantizers = init_activation_quantizers(self.node_q_cfg, get_activation_quantization_fn_factory)
self.active_quantization_config_index = max_candidate_idx # initialize with first config as default

def set_active_activation_quantizer(self, index: Optional[int]):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from collections.abc import Callable

from mct_quantizers import QuantizationMethod


from model_compression_toolkit.core.keras.quantization.fake_quant_builder import power_of_two_quantization, \
symmetric_quantization, uniform_quantization
from model_compression_toolkit.core.keras.quantization.lut_fake_quant import activation_lut_kmean_quantizer


"""
Mapping from a QuantizationMethod to an activation quantizer function.
"""
_activation_quantizer_factory_mapping = {
QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
QuantizationMethod.SYMMETRIC: symmetric_quantization,
QuantizationMethod.UNIFORM: uniform_quantization,
QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer
}


def get_activation_quantization_fn_factory(quantization_method: QuantizationMethod) -> Callable[[int, dict], Callable]:
"""
Get factory for activation quantizer.

Args:
quantization_method: quantization method for activation.

Returns:
Factory that accepts activation bitwidth and a dict of quantization params, and returns the quantizer.
"""
return _activation_quantizer_factory_mapping[quantization_method]
Loading