Skip to content

Commit fad922f

Browse files
authored
Initial quantization preparation (#1475)
1 parent 0d678a5 commit fad922f

71 files changed

Lines changed: 1239 additions & 1533 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

model_compression_toolkit/core/common/back2framework/base_model_builder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from abc import ABC, abstractmethod
1616
from typing import Any, Tuple
1717

18-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
1918
from model_compression_toolkit.core import common
2019
from model_compression_toolkit.core.common.user_info import UserInformation
2120

model_compression_toolkit/core/common/framework_info.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,12 @@
1616

1717
from collections.abc import Callable
1818
from enum import Enum
19-
from typing import Dict, Any, Tuple, NamedTuple
19+
from typing import Dict, Any, Tuple, NamedTuple, Optional
2020
from abc import ABC, abstractmethod
2121

2222
from mct_quantizers import QuantizationMethod
2323

2424

25-
# Default value to use for ops without kernel.
26-
# This is a weird default, but it's used all over the place, so for now only extract it to const so that it can be
27-
# referenced by variable instead of hard-coded.
28-
DEFAULT_KERNEL_ATTRIBUTE = None
29-
30-
3125
class ChannelAxis(Enum):
3226
"""
3327
@@ -63,7 +57,6 @@ class FrameworkInfo(ABC):
6357
kernel_ops_attribute_mapping (Dict): Dictionary from a framework operator to its weight attribute to quantize.
6458
out_channel_axis_mapping (Dict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
6559
_layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
66-
6760
"""
6861

6962
activation_quantizer_mapping: Dict[QuantizationMethod, Callable]
@@ -75,7 +68,7 @@ class FrameworkInfo(ABC):
7568
_default_channel_mapping = ChannelAxisMapping(None, None)
7669

7770
@classmethod
78-
def get_kernel_op_attribute(cls, node_type: Any) -> str:
71+
def get_kernel_op_attribute(cls, node_type: Any) -> Optional[str]:
7972
"""
8073
Get attribute of a layer's weight to quantize.
8174
@@ -85,20 +78,7 @@ def get_kernel_op_attribute(cls, node_type: Any) -> str:
8578
Returns:
8679
Attribute the layer has and should be quantized.
8780
"""
88-
return cls.kernel_ops_attribute_mapping.get(node_type, DEFAULT_KERNEL_ATTRIBUTE)
89-
90-
@classmethod
91-
def is_kernel_op(cls, node_type: Any) -> bool:
92-
"""
93-
Check is the node is a kernel operation.
94-
95-
Args:
96-
node_type: Layer to get its attributes.
97-
98-
Returns:
99-
True if node type is a kernel operation, else False.
100-
"""
101-
return node_type in cls.kernel_ops_attribute_mapping
81+
return cls.kernel_ops_attribute_mapping.get(node_type)
10282

10383
@classmethod
10484
def get_layer_min_max(cls, layer: Any, fw_attrs: Dict) -> Tuple[float, float]:
@@ -169,7 +149,6 @@ def get_fw_info():
169149
Returns: FrameworkInfo class.
170150
"""
171151
assert _current_framework_info is not None, "fw_info isn't initialized."
172-
assert issubclass(_current_framework_info, FrameworkInfo), "fw_info isn't initialized to a FrameworkInfo class."
173152
return _current_framework_info
174153

175154

model_compression_toolkit/core/common/fusion/graph_fuser.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
# ==============================================================================
1515

1616
import copy
17-
from typing import List, Tuple
17+
from typing import Tuple
1818

1919
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator
2020
from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
21-
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
22-
from itertools import product
21+
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
22+
CandidateNodeQuantizationConfig, NodeQuantizationConfig
2323

2424

2525
class FusedLayerType:
@@ -30,6 +30,7 @@ class FusedLayerType:
3030
def __init__(self):
3131
self.__name__ = 'FusedLayer'
3232

33+
3334
class GraphFuser:
3435
def apply_node_fusion(self, graph: Graph) -> Graph:
3536
"""
@@ -64,7 +65,6 @@ def apply_node_fusion(self, graph: Graph) -> Graph:
6465

6566
return graph_copy
6667

67-
6868
@staticmethod
6969
def _create_fused_node(fused_node_id: str, nodes: Tuple[BaseNode]) -> BaseNode:
7070
"""
@@ -86,10 +86,15 @@ def _create_fused_node(fused_node_id: str, nodes: Tuple[BaseNode]) -> BaseNode:
8686
weights={},
8787
layer_class=FusedLayerType)
8888

89+
base_cfg = CandidateNodeQuantizationConfig(
90+
activation_quantization_cfg=nodes[-1].quantization_cfg.base_quantization_cfg.activation_quantization_cfg,
91+
weights_quantization_cfg=None
92+
)
8993
activation_cfgs = [c.activation_quantization_cfg for c in nodes[-1].candidates_quantization_cfg]
90-
fused_node.candidates_quantization_cfg = [
91-
CandidateNodeQuantizationConfig(weights_quantization_cfg=None, activation_quantization_cfg=a) for a in
92-
activation_cfgs]
94+
candidates = [CandidateNodeQuantizationConfig(weights_quantization_cfg=None, activation_quantization_cfg=a)
95+
for a in activation_cfgs]
96+
fused_node.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=base_cfg,
97+
candidates_quantization_cfg=candidates)
9398

9499
# Keep the final configurations if they were set already.
95100
fused_node.final_weights_quantization_cfg = nodes[0].final_weights_quantization_cfg
@@ -158,5 +163,3 @@ def _replace_nodes_with_fused_node(graph: Graph,
158163

159164
# Finally, add the new fused node to the graph
160165
graph.add_node(fused_node)
161-
162-

model_compression_toolkit/core/common/graph/base_graph.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
4040
FrameworkQuantizationCapabilities
4141

42+
4243
def validate_graph_after_change(method: Callable) -> Callable:
4344
"""
4445
Decorator for graph-mutating methods. After the decorated method executes,
@@ -120,28 +121,13 @@ def fusing_info(self) -> FusingInfo:
120121
def fusing_info(self, fusing_info: FusingInfo):
121122
self._fusing_info = fusing_info
122123

123-
def set_fqc(self,
124-
fqc: FrameworkQuantizationCapabilities):
124+
def set_fqc(self, fqc: FrameworkQuantizationCapabilities):
125125
"""
126126
Set the graph's FQC.
127127
Args:
128128
fqc: FrameworkQuantizationCapabilities object.
129129
"""
130-
# validate graph nodes are either from the framework or a custom layer defined in the FQC
131-
# Validate graph nodes are either built-in layers from the framework or custom layers defined in the FQC
132-
fqc_layers = fqc.op_sets_to_layers.get_layers()
133-
fqc_filtered_layers = [layer for layer in fqc_layers if isinstance(layer, LayerFilterParams)]
134-
for n in self.nodes:
135-
is_node_in_fqc = any([n.is_match_type(_type) for _type in fqc_layers]) or \
136-
any([n.is_match_filter_params(filtered_layer) for filtered_layer in fqc_filtered_layers])
137-
if n.is_custom:
138-
if not is_node_in_fqc:
139-
Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
140-
' Please add the custom layer to Framework Quantization Capabilities (FQC), or file a feature '
141-
'request or an issue if you believe this should be supported.') # pragma: no cover
142-
if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(fqc).quantization_configurations]):
143-
Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.') # pragma: no cover
144-
130+
# TODO irena: this is only passed for negative shift activation.
145131
self.fqc = fqc
146132

147133
def get_topo_sorted_nodes(self):
@@ -578,7 +564,7 @@ def get_weights_configurable_nodes(self,
578564
A list of nodes that their weights can be configured (namely, has one or more weight qc candidate).
579565
"""
580566
# configurability is only relevant for kernel attribute quantization
581-
potential_conf_nodes = [n for n in list(self) if n.is_kernel_op]
567+
potential_conf_nodes = [n for n in self.nodes if n.kernel_attr]
582568

583569
def is_configurable(n):
584570
return n.is_configurable_weight(n.kernel_attr) and (not n.reuse or include_reused_nodes)
@@ -693,10 +679,8 @@ def get_final_weights_config(self) -> List[Tuple[BaseNode, int]]:
693679
"""
694680
Gets the final number of bits for quantization of each weights' configurable layer.
695681
696-
Args:
697-
fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
698-
699-
Returns: A list of pairs of (node type, node's weights quantization bitwidth).
682+
Returns:
683+
A list of pairs of (node type, node's weights quantization bitwidth).
700684
701685
"""
702686
sorted_conf_weights = self.get_sorted_weights_configurable_nodes()
@@ -876,32 +860,36 @@ def _find_intermediate_and_exit_nodes(self, entry_node: BaseNode, fw_impl: Any)
876860

877861
return intermediate_nodes, next_node
878862

863+
# TODO irena move to load_fqc and clean up tests (currently tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py)
879864
def override_fused_node_activation_quantization_candidates(self):
880865
"""
881866
Override fused node activation quantization candidates for all nodes in fused operations,
882867
except for the last node in each fused group.
883868
Update the value of quantization_config with the value of op_quaitization_cfg from FusingInfo.
884869
"""
885-
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
886-
887870
nodes_in_fln = self.fusing_info.get_inner_fln_nodes()
888871
for node in nodes_in_fln:
889872
fused_node_op_id = self.fusing_info.get_fused_op_id_for_node(node.name)
890-
fusiong_op_quaitization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
891-
org_candidate = node.candidates_quantization_cfg[0]
892-
if fusiong_op_quaitization_cfg is not None and fusiong_op_quaitization_cfg.enable_activation_quantization:
893-
# Set ActivationQuantizationMode to FLN_QUANT and update the value of quantization_config
894-
activation_quantization_cfg = NodeActivationQuantizationConfig(qc=org_candidate,
895-
op_cfg=fusiong_op_quaitization_cfg,
896-
activation_quantization_fn=org_candidate.activation_quantization_cfg.activation_quantization_fn,
897-
activation_quantization_params_fn=org_candidate.activation_quantization_cfg.activation_quantization_params_fn)
898-
activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
899-
for qc in node.candidates_quantization_cfg:
900-
qc.activation_quantization_cfg = activation_quantization_cfg
873+
fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
874+
if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization:
875+
def update(qc):
876+
qc.activation_quantization_cfg = NodeActivationQuantizationConfig(
877+
fusing_op_quantization_cfg,
878+
qc.activation_quantization_cfg.activation_quantization_fn,
879+
qc.activation_quantization_cfg.activation_quantization_params_fn
880+
)
881+
qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
882+
node.quantization_cfg.update_all(update)
883+
node.quantization_cfg.remove_duplicates()
901884
else:
902-
# Set ActivationQuantizationMode to FLN_NO_QUANT
885+
node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.FLN_NO_QUANT)
886+
# Remove duplicate candidates. We cannot compare whole candidates since activation configs might not
887+
# be identical, but we do want to treat them as such. So we only check duplication by weight configs.
888+
uniq_qcs = []
903889
for qc in node.candidates_quantization_cfg:
904-
qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_NO_QUANT
890+
if not any(qc.weights_quantization_cfg == uqc.weights_quantization_cfg for uqc in uniq_qcs):
891+
uniq_qcs.append(qc)
892+
node.quantization_cfg.candidates_quantization_cfg = uniq_qcs
905893

906894
def validate(self):
907895
"""

0 commit comments

Comments (0)