4 changes: 0 additions & 4 deletions docsrc/source/api/api_docs/modules/network_editor.rst
@@ -46,10 +46,6 @@ Actions
 
 |
 
-.. autoclass:: model_compression_toolkit.core.network_editor.ChangeQuantizationParamFunction
-
-|
-
 .. autoclass:: model_compression_toolkit.core.network_editor.ChangeFinalWeightsQuantizationMethod
 
 |
8 changes: 4 additions & 4 deletions model_compression_toolkit/core/common/framework_info.py
@@ -52,19 +52,19 @@ class FrameworkInfo(ABC):
         no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)
 
     Fields:
-        activation_quantizer_mapping (Dict[QuantizationMethod, Callable]): A dictionary mapping from QuantizationMethod to a quantization function.
         kernel_channels_mapping (Dict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
         kernel_ops_attribute_mapping (Dict): Dictionary from a framework operator to its weight attribute to quantize.
         out_channel_axis_mapping (Dict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
         _layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
+        activation_quantizer_factory_mapping: A mapping from QuantizationMethod to a factory function that accepts activation bitwidth and a dict of quantization params, and returns the corresponding quantization function.
     """
 
-    activation_quantizer_mapping: Dict[QuantizationMethod, Callable]
-    kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
     kernel_ops_attribute_mapping: Dict[Any, str]
+    kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
     out_channel_axis_mapping: Dict[Any, int]
-    _layer_min_max_mapping: Dict[Any, tuple]
+    activation_quantizer_factory_mapping: Dict[QuantizationMethod, Callable[[int, dict], Callable]]
 
+    _layer_min_max_mapping: Dict[Any, tuple]
     _default_channel_mapping = ChannelAxisMapping(None, None)
 
     @classmethod
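The new activation_quantizer_factory_mapping replaces the removed activation_quantizer_mapping: instead of mapping a method straight to a quantization function, it maps to a factory that builds the function from the bitwidth and params. A minimal sketch of the shape the declared type implies, assuming a hypothetical power-of-two factory and a 'threshold' params key (both illustrative, not this PR's actual factories):

import numpy as np
from typing import Callable, Dict
from mct_quantizers import QuantizationMethod

def power_of_two_quantizer_factory(activation_n_bits: int,
                                   quantization_params: dict) -> Callable:
    # Factory: bitwidth + computed params in, ready-to-call quantizer out.
    threshold = quantization_params.get('threshold', 1.0)  # assumed param key

    def quantize(tensor: np.ndarray) -> np.ndarray:
        # Uniform symmetric grid over [-threshold, threshold).
        step = 2 * threshold / (2 ** activation_n_bits)
        return np.clip(np.round(tensor / step) * step, -threshold, threshold - step)

    return quantize

# Matches the declared type Dict[QuantizationMethod, Callable[[int, dict], Callable]].
activation_quantizer_factory_mapping: Dict[QuantizationMethod, Callable] = {
    QuantizationMethod.POWER_OF_TWO: power_of_two_quantizer_factory,
}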
6 changes: 1 addition & 5 deletions model_compression_toolkit/core/common/graph/base_graph.py
@@ -873,11 +873,7 @@ def override_fused_node_activation_quantization_candidates(self):
             fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
             if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization:
                 def update(qc):
-                    qc.activation_quantization_cfg = NodeActivationQuantizationConfig(
-                        fusing_op_quantization_cfg,
-                        qc.activation_quantization_cfg.activation_quantization_fn,
-                        qc.activation_quantization_cfg.activation_quantization_params_fn
-                    )
+                    qc.activation_quantization_cfg = NodeActivationQuantizationConfig(fusing_op_quantization_cfg)
                     qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
                 node.quantization_cfg.update_all(update)
                 node.quantization_cfg.remove_duplicates()
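The constructor now takes only the fused-op config; the activation_quantization_fn and activation_quantization_params_fn arguments are gone, because quantizers are derived from the config's declarative fields when needed rather than stored on it. A rough stand-in illustrating that derive-on-demand pattern (illustrative classes, not the real MCT ones):

from dataclasses import dataclass
from typing import Any

@dataclass
class OpQuantizationConfigSketch:
    # Stand-in for the fused-op config: declarative fields only.
    activation_quantization_method: Any
    activation_n_bits: int

@dataclass
class NodeActivationQuantizationConfigSketch:
    activation_quantization_method: Any
    activation_n_bits: int

    @classmethod
    def from_op_cfg(cls, op_cfg: OpQuantizationConfigSketch):
        # Copy only the fields; the quantizer callable is resolved later
        # from them (see get_activation_quantization_fn below), not stored here.
        return cls(op_cfg.activation_quantization_method, op_cfg.activation_n_bits)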
@@ -18,6 +18,8 @@
 
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
+from model_compression_toolkit.core.common.quantization.quantization_fn_selection import (get_activation_quantization_fn,
+                                                                                          get_weights_quantization_fn)
 
 
 def verify_candidates_descending_order(node_q_cfg: List[CandidateNodeQuantizationConfig],
@@ -77,13 +79,13 @@ def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig],
     quantized_weights = []
     for qc in node_q_cfg:
         qc_weights_attr = qc.weights_quantization_cfg.get_attr_config(kernel_attr)
-        q_weight = qc_weights_attr.weights_quantization_fn(float_weights,
-                                                           qc_weights_attr.weights_n_bits,
-                                                           True,
-                                                           qc_weights_attr.weights_quantization_params,
-                                                           qc_weights_attr.weights_per_channel_threshold,
-                                                           qc_weights_attr.weights_channels_axis[
-                                                               0])  # output channel axis
+        weights_quantization_fn = get_weights_quantization_fn(qc_weights_attr.weights_quantization_method)
+        q_weight = weights_quantization_fn(float_weights,
+                                           qc_weights_attr.weights_n_bits,
+                                           True,
+                                           qc_weights_attr.weights_quantization_params,
+                                           qc_weights_attr.weights_per_channel_threshold,
+                                           qc_weights_attr.weights_channels_axis[0])  # output channel axis
 
         quantized_weights.append(fw_tensor_convert_func(q_weight))
 
@@ -105,6 +107,7 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig],
     activation_quantizers = []
     for index, qc in enumerate(node_q_cfg):
         q_activation = node_q_cfg[index].activation_quantization_cfg
-        activation_quantizers.append(q_activation.quantize_node_output)
+        quantizer = get_activation_quantization_fn(q_activation)
+        activation_quantizers.append(quantizer)
 
     return activation_quantizers
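The two selectors imported above are not shown in this diff. A plausible sketch of the dispatch pattern they implement, assuming the FrameworkInfo factory mapping introduced in this PR and the config field names used elsewhere in this file (signatures and internals are assumptions):

from typing import Callable
from model_compression_toolkit.core.common.framework_info import get_fw_info

def get_activation_quantization_fn_sketch(activation_cfg) -> Callable:
    # Look up the factory registered for the config's method, then build the
    # concrete quantizer from its bitwidth and computed quantization params.
    factory = get_fw_info().activation_quantizer_factory_mapping[
        activation_cfg.activation_quantization_method]
    return factory(activation_cfg.activation_n_bits,
                   activation_cfg.activation_quantization_params)

def get_weights_quantization_fn_sketch(method) -> Callable:
    # Registry dispatch: unknown methods fail loudly instead of returning None.
    registry = {}  # e.g. {QuantizationMethod.POWER_OF_TWO: pot_weights_quantizer}
    fn = registry.get(method)
    if fn is None:
        raise ValueError(f'Unknown weights quantization method: {method}')
    return fn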
@@ -13,7 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
-from model_compression_toolkit.core.common.network_editors.actions import ChangeCandidatesWeightsQuantConfigAttr, ChangeFinalWeightsQuantConfigAttr, ChangeCandidatesActivationQuantConfigAttr, ChangeQuantizationParamFunction, ChangeCandidatesActivationQuantizationMethod, ChangeFinalWeightsQuantizationMethod, ChangeCandidatesWeightsQuantizationMethod, ChangeFinalActivationQuantConfigAttr
+from model_compression_toolkit.core.common.network_editors.actions import (
+    ChangeCandidatesWeightsQuantConfigAttr,
+    ChangeFinalWeightsQuantConfigAttr,
+    ChangeCandidatesActivationQuantConfigAttr,
+    ChangeCandidatesActivationQuantizationMethod,
+    ChangeFinalWeightsQuantizationMethod,
+    ChangeCandidatesWeightsQuantizationMethod,
+    ChangeFinalActivationQuantConfigAttr)
 from model_compression_toolkit.core.common.network_editors.actions import EditRule
 from model_compression_toolkit.core.common.network_editors.node_filters import NodeTypeFilter, NodeNameScopeFilter, \
     NodeNameFilter
98 changes: 3 additions & 95 deletions model_compression_toolkit/core/common/network_editors/actions.py
@@ -20,14 +20,8 @@
 from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.logger import Logger
 
-
-from model_compression_toolkit.core.common.framework_info import get_fw_info
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
-from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
-    get_activation_quantization_params_fn, get_weights_quantization_params_fn
-from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
-    get_weights_quantization_fn
 
 
 _EditRule = namedtuple('EditRule', 'filter action')
@@ -174,47 +168,6 @@ def apply(self, node: BaseNode, graph):
         node.final_activation_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value)
 
 
-class ChangeQuantizationParamFunction(BaseAction):
-    """
-    Class ChangeQuantizationParamFunction to change a node's weights/activations quantization params function.
-    """
-
-    def __init__(self,
-                 attr_name: str = None,
-                 activation_quantization_params_fn: Callable = None,
-                 weights_quantization_params_fn: Callable = None):
-        """
-        Init a ChangeQuantizationParamFunction object.
-
-        Args:
-            attr_name: The weights attribute's name to set the weights quantization params function for (if setting weights params).
-            activation_quantization_params_fn: a params function for a node's activations.
-            weights_quantization_params_fn: a params function for a node's weights.
-        """
-        self.activation_quantization_params_fn = activation_quantization_params_fn
-        self.weights_quantization_params_fn = weights_quantization_params_fn
-        self.attr_name = attr_name
-
-    def apply(self, node: BaseNode, graph):
-        """
-        Change the node's weights/activations quantization params function.
-
-        Args:
-            node: Node object to change its quantization params function.
-            graph: Graph to apply the action on.
-
-        Returns:
-            The node after its quantization params function has been modified.
-        """
-        for nqc in node.candidates_quantization_cfg:
-            if self.activation_quantization_params_fn is not None:
-                nqc.activation_quantization_cfg.set_activation_quantization_params_fn(
-                    self.activation_quantization_params_fn)
-            if self.weights_quantization_params_fn is not None:
-                attr_config = nqc.weights_quantization_cfg.get_attr_config(self.attr_name)
-                attr_config.override_weights_quantization_params_fn(self.weights_quantization_params_fn)
-
-
 class ChangeFinalActivationQuantizationMethod(BaseAction):
     """
     Class ChangeFinalActivationQuantizationMethod to change a node's weights/activations quantizer function.
@@ -243,16 +196,6 @@ def apply(self, node: BaseNode, graph):
         """
 
         if self.activation_quantization_method is not None and node.final_activation_quantization_cfg is not None:
-
-            activation_quantization_params_fn = get_activation_quantization_params_fn(
-                self.activation_quantization_method)
-
-            node.final_activation_quantization_cfg.set_activation_quantization_params_fn(
-                activation_quantization_params_fn)
-
-            activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(self.activation_quantization_method)
-
-            node.final_activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
             node.final_activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
 
 
@@ -281,23 +224,12 @@ def apply(self, node: BaseNode, graph):
         """
         if self.activation_quantization_method is not None:
             for qc in node.candidates_quantization_cfg:
-                activation_quantization_params_fn = get_activation_quantization_params_fn(
-                    self.activation_quantization_method)
-
-                qc.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
-                activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(
-                    self.activation_quantization_method)
-
-                if activation_quantization_fn is None:
-                    Logger.critical('Unknown activation quantization method specified.')  # pragma: no cover
-
-                qc.activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
                 qc.activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
 
 
 class ChangeFinalWeightsQuantizationMethod(BaseAction):
     """
-    Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer function.
+    Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer method.
     """
 
     def __init__(self, attr_name: str, weights_quantization_method=None):
@@ -323,21 +255,8 @@ def apply(self, node: BaseNode, graph):
         """
 
         if self.weights_quantization_method is not None and node.final_weights_quantization_cfg is not None:
-
-            weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
-
             attr_config = node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
-            attr_config.override_weights_quantization_params_fn(weights_quantization_params_fn)
-
-            weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
-
-            if weights_quantization_fn is None:
-                Logger.critical('Unknown weights quantization method specified.')  # pragma: no cover
-
-            attr_config = node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
-            attr_config.override_weights_quantization_fn(weights_quantization_fn)
-            node.final_weights_quantization_cfg.get_attr_config(self.attr_name).weights_quantization_method = \
-                self.weights_quantization_method
+            attr_config.weights_quantization_method = self.weights_quantization_method
 
 
 class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
@@ -370,18 +289,7 @@ def apply(self, node: BaseNode, graph: Graph):
 
         if self.weights_quantization_method is not None:
             for qc in node.candidates_quantization_cfg:
-
-                weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
-
                 attr_qc = qc.weights_quantization_cfg.get_attr_config(self.attr_name)
-                attr_qc.override_weights_quantization_params_fn(weights_quantization_params_fn)
-
-                weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
-
-                if weights_quantization_fn is None:
-                    Logger.critical('Unknown weights quantization method specified.')  # pragma: no cover
-
-                attr_qc.override_weights_quantization_fn(weights_quantization_fn)
+                attr_qc.weights_quantization_method = self.weights_quantization_method
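With the params-fn and quantizer-fn plumbing removed, each of these actions now just sets weights_quantization_method (or activation_quantization_method) on the relevant configs. A hedged usage sketch of the edit-rule API after this PR; the Keras layer type and the 'kernel' attribute name are illustrative assumptions:

import tensorflow as tf
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.core.common.network_editors.actions import (
    EditRule, ChangeCandidatesWeightsQuantizationMethod)
from model_compression_toolkit.core.common.network_editors.node_filters import NodeTypeFilter

# Switch the kernel quantization method of all Conv2D candidates to symmetric.
edit_rules = [
    EditRule(filter=NodeTypeFilter(tf.keras.layers.Conv2D),
             action=ChangeCandidatesWeightsQuantizationMethod(
                 attr_name='kernel',
                 weights_quantization_method=QuantizationMethod.SYMMETRIC)),
]

Such rules would typically be handed to MCT through a DebugConfig's network_editor list; since the action no longer swaps quantizer functions itself, the concrete quantizer is resolved later from the updated method field.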