diff --git a/model_compression_toolkit/constants.py b/model_compression_toolkit/constants.py index 93c62012b..59f790936 100644 --- a/model_compression_toolkit/constants.py +++ b/model_compression_toolkit/constants.py @@ -138,3 +138,8 @@ NODE_NAME = 'node_name' TOTAL_SIZE = 'total_size' NODE_OUTPUT_INDEX = 'node_output_index' + + +# Fusing Patterns constants +FUSED_LAYER_PATTERN = 'fused_layer_pattern' +FUSED_OP_QUANT_CONFIG = 'fused_op_quantization_config' \ No newline at end of file diff --git a/model_compression_toolkit/core/common/fusion/fusing_info.py b/model_compression_toolkit/core/common/fusion/fusing_info.py index 81e2fe806..08c8656aa 100644 --- a/model_compression_toolkit/core/common/fusion/fusing_info.py +++ b/model_compression_toolkit/core/common/fusion/fusing_info.py @@ -14,6 +14,8 @@ # ============================================================================== from model_compression_toolkit.target_platform_capabilities import LayerFilterParams +from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig +from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG from dataclasses import dataclass, field from typing import Optional, List, Dict, Any, Tuple @@ -41,6 +43,7 @@ class FusingInfo: fusing_patterns: any = None fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict) node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict) + fusing_data_to_quantization_config_map: Dict[str, OpQuantizationConfig] = field(default_factory=dict) def __post_init__(self): """Validates and initializes mappings after dataclass instantiation.""" @@ -49,6 +52,7 @@ def __post_init__(self): assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}" self._init_node_mapping() + self._init_quantization_config_map() def _init_node_mapping(self) -> None: """ @@ -59,6 +63,17 @@ def _init_node_mapping(self) -> 
def _init_quantization_config_map(self) -> None:
    """
    Rebuild the mapping from each fused operation ID to the quantization
    configuration of the fusing pattern it matches.

    The map is cleared first. When several patterns match the same fused
    operation, the configuration of the last matching pattern is kept
    (resolution follows the order of ``fusing_patterns``).
    """
    self.fusing_data_to_quantization_config_map.clear()
    if self.fusing_patterns is None:
        # No patterns available — leave the map empty.
        return
    for op_id, fused_nodes in self.fusing_data.items():
        for pattern in self.fusing_patterns:
            if is_valid_fusion([pattern], fused_nodes):
                self.fusing_data_to_quantization_config_map[op_id] = pattern.get(FUSED_OP_QUANT_CONFIG)
+ """ + return self.fusing_data_to_quantization_config_map.copy() + def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]: """ Retrieve the list of nodes for a given fused operation ID. @@ -129,6 +160,18 @@ def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]: """ return self.fusing_data.get(op_id) + def get_fused_op_quantization_config(self, op_id: str) -> OpQuantizationConfig: + """ + Retrieve the quantization configuration for a given fused operation ID. + + Args: + op_id (str): The identifier for the fused operation. + + Returns: + OpQuantizationConfig: The quantization configuration for the operation, or None if not found. + """ + return self.fusing_data_to_quantization_config_map.get(op_id) + def is_node_in_fused_op(self, node: 'BaseNode') -> bool: """ Check if a node is part of any fused operation. @@ -284,11 +327,13 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo: - Fusions are linear sequences (each node has exactly one successor). - Each node belongs to at most one fused operation. 
""" - if not self._fusing_patterns: + fusing_layer_patterns = [f.get(FUSED_LAYER_PATTERN) for f in self._fusing_patterns] + + if not fusing_layer_patterns: return FusingInfo(fusing_patterns=self._fusing_patterns) # Find max fusion - max_layers_fusing = max([len(fusing_pattern) for fusing_pattern in self._fusing_patterns]) + max_fusing_layer_patterns = max([len(fusing_pattern.get(FUSED_LAYER_PATTERN)) for fusing_pattern in self._fusing_patterns]) # Travel along the graph to find layers for fusing nodes = graph.get_topo_sorted_nodes() @@ -302,9 +347,9 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo: continue # Start fusing search fusing_nodes = [] # nodes that are candidates for participating in fusing - patterns = copy.deepcopy(self._fusing_patterns) + patterns = copy.deepcopy(fusing_layer_patterns) next_nodes = [node] - for i in range(max_layers_fusing): + for i in range(max_fusing_layer_patterns): patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i) if len(patterns) == 0: # Give up if no more fusion pattern break @@ -361,10 +406,11 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) - if fusion_depth <= 1: return False for fusing_pattern in fusing_patterns: - if fusion_depth != len(fusing_pattern): + fusing_layer_pattern = fusing_pattern.get(FUSED_LAYER_PATTERN) + if fusion_depth != len(fusing_layer_pattern): continue counter = 0 - for i, layer in enumerate(fusing_pattern): + for i, layer in enumerate(fusing_layer_pattern): if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or \ nodes[i].is_match_type(layer): counter += 1 diff --git a/model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py b/model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py index 6e78b3821..e49cb5e0d 100644 --- 
def get_fusing_patterns(self) -> List[Dict[str, Any]]:
    """
    Expand the TPC fusing patterns into concrete layer patterns with their
    quantization configurations.

    Each operator group in a pattern is resolved to the framework layers
    attached to it, and the cartesian product of those groups enumerates all
    concrete layer sequences. Every sequence is paired with the pattern's
    fused-op quantization configuration (None when the pattern defines none).

    Returns:
        A list of dicts, each containing:
            FUSED_LAYER_PATTERN: list of layers/LayerFilterParams to fuse.
            FUSED_OP_QUANT_CONFIG: the pattern's quantization config, or None.
    """
    patterns = []
    if self.tpc.fusing_patterns is None:
        return patterns

    for p in self.tpc.fusing_patterns:
        # Resolve each operator group to its attached framework layers, then
        # take the cartesian product to enumerate concrete layer sequences.
        ops = [self.get_layers_by_opset(x) for x in p.operator_groups]
        # Older schema versions may lack this attribute; default to None
        # (idiomatic replacement for the hasattr/else chain).
        fused_op_quantization_config = getattr(p, 'fused_op_quantization_config', None)

        for layer_pattern in itertools.product(*ops):
            patterns.append({FUSED_LAYER_PATTERN: list(layer_pattern),
                             FUSED_OP_QUANT_CONFIG: fused_op_quantization_config})

    return patterns
def create_mock_base_node(name: str, layer_class: str):
    """
    Build a Mock node that mimics a real BaseNode for fusing tests.

    A real BaseNode is constructed (with empty attrs/shapes/weights) so the
    mock can borrow its genuine ``is_match_type`` implementation, while
    ``name`` and ``layer_class`` remain accessible on the mock itself.
    """
    backing_node = BaseNode(name=name,
                            layer_class=layer_class,
                            framework_attr={},
                            input_shape=(),
                            output_shape=(),
                            weights={})

    mock_node = Mock(spec=backing_node)
    # Use the real type-matching logic so fusion-pattern checks behave as in production.
    mock_node.is_match_type = backing_node.is_match_type
    mock_node.layer_class = layer_class
    mock_node.name = name

    return mock_node
+ """ + return [{FUSED_LAYER_PATTERN: ["Conv2d", "ReLU"], FUSED_OP_QUANT_CONFIG: TEST_QC}, + {FUSED_LAYER_PATTERN: ["Conv2d", "Tanh"], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: ["Linear", "Softmax"], FUSED_OP_QUANT_CONFIG: TEST_QC }] + +@pytest.fixture +def fusing_info_generator_with_qconfig(fusing_patterns_with_qconfig): + """ + Creates a FusingInfoGenerator using the fusing patterns. + """ + return FusingInfoGenerator(fusing_patterns_with_qconfig) + +@pytest.fixture +def mock_qconfig_set_nodes(): + """ + Creates mock nodes representing a simple neural network structure. + - Nodes: Conv2D, ReLU, Conv2D, Tanh, Linear, Softmax. + """ + node1 = create_mock_base_node(name='conv', layer_class='Conv2d') + node2 = create_mock_base_node(name='relu', layer_class='ReLU') + node3 = create_mock_base_node(name='conv_2', layer_class='Conv2d') + node4 = create_mock_base_node(name='tanh', layer_class='Tanh') + node5 = create_mock_base_node(name='linear', layer_class='Linear') + node6 = create_mock_base_node(name='softmax', layer_class='Softmax') + + return [node1, node2, node3, node4, node5, node6] + + +@pytest.fixture +def mock_qconfig_set_graph(mock_qconfig_set_nodes): + """ + Creates a mock graph with topologically sorted nodes and defined connectivity. + - Implements `get_next_nodes` and `get_prev_nodes` to maintain linear order. 
+ """ + mock_nodes = mock_qconfig_set_nodes + + graph = Mock() + graph.nodes.return_value = mock_nodes + graph.get_topo_sorted_nodes.return_value = mock_nodes + + adjacency = { + mock_nodes[0]: [mock_nodes[1]], # conv -> relu + mock_nodes[1]: [mock_nodes[2]], # relu -> conv_2 + mock_nodes[2]: [mock_nodes[3]], # conv_2 -> silu + mock_nodes[3]: [mock_nodes[4]], # silu -> linear + mock_nodes[4]: [mock_nodes[5]], # linear -> softmax + mock_nodes[5]: [] # softmax has no outputs + } + + reverse_adjacency = { + mock_nodes[0]: [], # conv has no inputs + mock_nodes[1]: [mock_nodes[0]], # relu <- conv + mock_nodes[2]: [mock_nodes[1]], # conv_2 <- relu + mock_nodes[3]: [mock_nodes[2]], # silu <- conv_2 + mock_nodes[4]: [mock_nodes[3]], # linear <- silu + mock_nodes[5]: [mock_nodes[4]] # softmax <- linear + } + + graph.get_next_nodes.side_effect = lambda node: adjacency.get(node, []) + graph.get_prev_nodes.side_effect = lambda node: reverse_adjacency.get(node, []) + + return graph + + +def test_fusing_info_qconfig_mapping(mock_qconfig_set_graph, fusing_info_generator_with_qconfig): + """ + Tests that each node is correctly mapped to its fused quantization configs. + """ + fi = fusing_info_generator_with_qconfig.generate_fusing_info(mock_qconfig_set_graph) + fi_qconfig_map = fi.get_fusing_quantization_config_map() + + expected_op1_id = f"{FUSED_OP_ID_PREFIX}conv_relu" + expected_op2_id = f"{FUSED_OP_ID_PREFIX}conv_2_tanh" + expected_op3_id = f"{FUSED_OP_ID_PREFIX}linear_softmax" + + assert len(fi_qconfig_map) == 3 + assert fi_qconfig_map[expected_op1_id] == TEST_QC + assert fi_qconfig_map[expected_op2_id] == None + assert fi_qconfig_map[expected_op3_id] == TEST_QC + + +def test_add_fused_operation_adds_data_and_qconfig(mock_qconfig_set_graph, fusing_info_generator_with_qconfig): + """ + Tests whether the added node is correctly assigned the fused quantization config. 
+ """ + + fi = fusing_info_generator_with_qconfig.generate_fusing_info(mock_qconfig_set_graph) + fi_qconfig_map = fi.get_fusing_quantization_config_map() + + ### Checking the number of mappings before addition + assert len(fi_qconfig_map) == 3 + + node1 = create_mock_base_node(name='conv_a', layer_class='Conv2d') + node2 = create_mock_base_node(name='relu_b', layer_class='ReLU') + + op_id = f"{FUSED_OP_ID_PREFIX}conv_a_relu_b" + fi.add_fused_operation(op_id, (node1, node2)) + fi_qconfig_map = fi.get_fusing_quantization_config_map() + + ### Checking the mapping information after addition + assert op_id in fi.get_all_fused_operations() + assert fi.get_fused_node_name("conv_a") == op_id + assert fi.get_fused_node_name("relu_b") == op_id + + assert len(fi_qconfig_map) == 4 + assert fi.get_fused_op_quantization_config(op_id) == TEST_QC + + +def test_remove_fusing_data_and_qconfig(mock_qconfig_set_graph, fusing_info_generator_with_qconfig, mock_qconfig_set_nodes): + """ + Tests that the fused quantization config for the specified operation is removed from the map. + """ + + fi = fusing_info_generator_with_qconfig.generate_fusing_info(mock_qconfig_set_graph) + + ### Delete Conv2D + ReLU pattern. + conv_node, relu_node, _, _, _, _ = mock_qconfig_set_nodes + op_id = f"{FUSED_OP_ID_PREFIX}conv_relu" + + ### Checking the mapping information before deletion. + assert len(fi.get_fusing_quantization_config_map()) == 3 + assert fi.get_fused_op_quantization_config(op_id) == TEST_QC + assert fi.get_fused_nodes(op_id) == (conv_node, relu_node) + + fi.remove_fused_operation(op_id) + fi_qconfig_map = fi.get_fusing_quantization_config_map() + + ### Checking the mapping information after deletion. + assert len(fi.get_fusing_quantization_config_map()) == 2 + assert fi.get_fused_op_quantization_config(op_id) == None + assert fi.get_fused_nodes(op_id) == None