diff --git a/model_compression_toolkit/constants.py b/model_compression_toolkit/constants.py index 93c62012b..59f790936 100644 --- a/model_compression_toolkit/constants.py +++ b/model_compression_toolkit/constants.py @@ -138,3 +138,8 @@ NODE_NAME = 'node_name' TOTAL_SIZE = 'total_size' NODE_OUTPUT_INDEX = 'node_output_index' + + +# Fusing Patterns constants +FUSED_LAYER_PATTERN = 'fused_layer_pattern' +FUSED_OP_QUANT_CONFIG = 'fused_op_quantization_config' \ No newline at end of file diff --git a/model_compression_toolkit/core/common/fusion/fusing_info.py b/model_compression_toolkit/core/common/fusion/fusing_info.py index 81e2fe806..c1a059f41 100644 --- a/model_compression_toolkit/core/common/fusion/fusing_info.py +++ b/model_compression_toolkit/core/common/fusion/fusing_info.py @@ -14,6 +14,8 @@ # ============================================================================== from model_compression_toolkit.target_platform_capabilities import LayerFilterParams +from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig +from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG from dataclasses import dataclass, field from typing import Optional, List, Dict, Any, Tuple @@ -41,6 +43,7 @@ class FusingInfo: fusing_patterns: any = None fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict) node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict) + fused_op_id_to_quant_config: Dict[str, OpQuantizationConfig] = field(default_factory=dict) def __post_init__(self): """Validates and initializes mappings after dataclass instantiation.""" @@ -49,6 +52,7 @@ def __post_init__(self): assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}" self._init_node_mapping() + self._init_quantization_config_map() def _init_node_mapping(self) -> None: """ @@ -59,6 +63,15 @@ def _init_node_mapping(self) -> None: for node in nodes: self.node_to_fused_node_map[node.name] = op_id + def _init_quantization_config_map(self) -> None: + """ + Init the mapping between fused operation IDs and their quantization configurations. + """ + self.fused_op_id_to_quant_config.clear() + if self.fusing_patterns is not None: + for op_id, nodes in self.fusing_data.items(): + self.set_fused_op_quantization_config(op_id, nodes) + def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None: """ Add a new fused operation with the given ID and set of nodes. @@ -78,6 +91,22 @@ def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None: for node in nodes: self.node_to_fused_node_map[node.name] = op_id + # Update the quantization config mapping for this operation + if self.fusing_patterns is not None: + self.set_fused_op_quantization_config(op_id, nodes) + + def set_fused_op_quantization_config(self, op_id: str, nodes: Tuple['BaseNode']) -> None: + """ + Set the quantization configuration for a given fused operation ID. + + Args: + op_id (str): The identifier for the fused operation. + nodes (Tuple[BaseNode]): The tuple of nodes that form the fused operation. + """ + fusing_pattern = next((fp for fp in self.fusing_patterns if is_valid_fusion([fp.get(FUSED_LAYER_PATTERN)], nodes)), None) + if fusing_pattern is not None: + self.fused_op_id_to_quant_config[op_id] = fusing_pattern.get(FUSED_OP_QUANT_CONFIG) + def remove_fused_operation(self, op_id: str) -> None: """ Remove a fused operation by its ID. @@ -95,6 +124,7 @@ def remove_fused_operation(self, op_id: str) -> None: for node in nodes: self.node_to_fused_node_map.pop(node.name, None) del self.fusing_data[op_id] + self.fused_op_id_to_quant_config.pop(op_id, None) def get_fused_node_name(self, node_name: str) -> Optional[str]: """ @@ -117,6 +147,15 @@ def get_node_to_fused_node_map(self) -> Dict[str, str]: """ return self.node_to_fused_node_map.copy() + def get_fusing_quantization_config_map(self) -> Dict[str, OpQuantizationConfig]: + """ + Retrieve a copy of the mapping from fused operation IDs to their quantization configurations. + + Returns: + A dictionary mapping each fused operation ID to its quantization configuration. + """ + return self.fused_op_id_to_quant_config.copy() + def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]: """ Retrieve the list of nodes for a given fused operation ID. @@ -129,6 +168,18 @@ def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]: """ return self.fusing_data.get(op_id) + def get_fused_op_quantization_config(self, op_id: str) -> OpQuantizationConfig: + """ + Retrieve the quantization configuration for a given fused operation ID. + + Args: + op_id (str): The identifier for the fused operation. + + Returns: + OpQuantizationConfig: The quantization configuration for the operation, or None if not found. + """ + return self.fused_op_id_to_quant_config.get(op_id) + def is_node_in_fused_op(self, node: 'BaseNode') -> bool: """ Check if a node is part of any fused operation. @@ -216,10 +267,11 @@ def validate(self, graph: 'Graph') -> None: all_fused_nodes.update(node_set) # Check 4: Ensure the sequence matches a valid fusing pattern - if not is_valid_fusion(self.fusing_patterns, nodes): + valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns) + if not is_valid_fusion(valid_fusing_patterns, nodes): raise ValueError( f"Fused operation {op_id} does not match any valid fusing pattern " - f"from {self.fusing_patterns}." + f"from {valid_fusing_patterns}." ) def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool: @@ -240,7 +292,8 @@ def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool: return False # Check if the provided nodes match a valid fusion pattern - return is_valid_fusion(fusing_patterns=self.fusing_patterns, nodes=nodes) + valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns) + return is_valid_fusion(fusing_patterns=valid_fusing_patterns, nodes=nodes) def __repr__(self) -> str: """ @@ -287,8 +340,11 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo: if not self._fusing_patterns: return FusingInfo(fusing_patterns=self._fusing_patterns) + # Extract fusing layer patterns + fusing_layer_patterns = _get_fusing_layer_patterns(self._fusing_patterns) + # Find max fusion - max_layers_fusing = max([len(fusing_pattern) for fusing_pattern in self._fusing_patterns]) + max_layer_patterns = max([len(fusing_layer_pattern) for fusing_layer_pattern in fusing_layer_patterns]) # Travel along the graph to find layers for fusing nodes = graph.get_topo_sorted_nodes() @@ -302,9 +358,9 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo: continue # Start fusing search fusing_nodes = [] # nodes that are candidates for participating in fusing - patterns = copy.deepcopy(self._fusing_patterns) + patterns = copy.deepcopy(fusing_layer_patterns) next_nodes = [node] - for i in range(max_layers_fusing): + for i in range(max_layer_patterns): patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i) if len(patterns) == 0: # Give up if no more fusion pattern break @@ -314,7 +370,7 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo: break # New fusion - if is_valid_fusion(self._fusing_patterns, fusing_nodes): + if is_valid_fusion(fusing_layer_patterns, fusing_nodes): fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes) assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}" fusing_info[fused_op_id] = tuple(fusing_nodes) @@ -371,3 +427,15 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) - if counter == fusion_depth: return True return False + + +def _get_fusing_layer_patterns(fusing_patterns: List[Dict[Any, OpQuantizationConfig]]) -> List[List[Any]]: + """ + Extracts the fusing layer patterns from the provided fusing patterns. + Args: + fusing_patterns: List of patterns of layers/LayerFilterParams to fuse and their mapping quantization config. + + Returns: + supported fusing layer patterns + """ + return [f.get(FUSED_LAYER_PATTERN) for f in fusing_patterns] diff --git a/model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py b/model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py index 6e78b3821..d763d0911 100644 --- a/model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py +++ b/model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py @@ -31,6 +31,9 @@ OpQuantizationConfig, QuantizationConfigOptions from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc +from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG + + class FrameworkQuantizationCapabilities(ImmutableClass): """ Attach framework information to a modeled hardware. @@ -94,20 +97,26 @@ def get_layers_by_opset(self, op: OperatorsSetBase) -> List[Any]: """ return self.op_sets_to_layers.get_layers_by_op(op) - def get_fusing_patterns(self) -> List[List[Any]]: + def get_fusing_patterns(self) -> List[Dict[List[Any], OpQuantizationConfig]]: """ - Returns: List of patterns of layers/LayerFilterParams to fuse. + Returns: List of patterns of layers/LayerFilterParams to fuse and their mapping quantization config. """ - res = [] + + patterns = [] if self.tpc.fusing_patterns is None: - return res + return patterns + for p in self.tpc.fusing_patterns: + res = [] ops = [self.get_layers_by_opset(x) for x in p.operator_groups] res.extend(itertools.product(*ops)) - return [list(x) for x in res] + fused_op_quant_config = getattr(p, FUSED_OP_QUANT_CONFIG, None) + patterns.extend({FUSED_LAYER_PATTERN: list(x), FUSED_OP_QUANT_CONFIG: fused_op_quant_config} for x in res) + + return patterns def get_info(self) -> Dict[str, Any]: """ diff --git a/tests/keras_tests/non_parallel_tests/test_keras_tpc.py b/tests/keras_tests/non_parallel_tests/test_keras_tpc.py index e113a7937..995ce8b02 100644 --- a/tests/keras_tests/non_parallel_tests/test_keras_tpc.py +++ b/tests/keras_tests/non_parallel_tests/test_keras_tpc.py @@ -42,7 +42,7 @@ from keras import Input import model_compression_toolkit as mct -from model_compression_toolkit.constants import TENSORFLOW +from model_compression_toolkit.constants import TENSORFLOW, FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \ QNNPACK_TP_MODEL, TFLITE_TP_MODEL, KERNEL_ATTR, BIAS_ATTR, KERAS_KERNEL, BIAS, WEIGHTS_N_BITS from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation @@ -250,7 +250,7 @@ def test_keras_fusing_patterns(self): fusings = hm_keras.get_fusing_patterns() self.assertEqual(len(fusings), 2) - p0, p1 = fusings[0], fusings[1] + p0, p1 = fusings[0].get(FUSED_LAYER_PATTERN), fusings[1].get(FUSED_LAYER_PATTERN) self.assertEqual(len(p0), 3) self.assertEqual(p0[0], Conv2D) diff --git a/tests/pytorch_tests/function_tests/test_pytorch_tpc.py b/tests/pytorch_tests/function_tests/test_pytorch_tpc.py index 1809b688a..1d3353d67 100644 --- a/tests/pytorch_tests/function_tests/test_pytorch_tpc.py +++ b/tests/pytorch_tests/function_tests/test_pytorch_tpc.py @@ -26,7 +26,7 @@ import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema from model_compression_toolkit.core import MixedPrecisionQuantizationConfig from model_compression_toolkit.defaultdict import DefaultDict -from model_compression_toolkit.constants import PYTORCH +from model_compression_toolkit.constants import PYTORCH, FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \ TFLITE_TP_MODEL, QNNPACK_TP_MODEL, KERNEL_ATTR, WEIGHTS_N_BITS, PYTORCH_KERNEL, BIAS_ATTR, BIAS @@ -236,15 +236,15 @@ def test_pytorch_fusing_patterns(self): fusing_patterns=tuple(fusing_patterns), add_metadata=False) - hm_keras = FrameworkQuantizationCapabilities(hm) - with hm_keras: + hm_torch = FrameworkQuantizationCapabilities(hm) + with hm_torch: OperationsSetToLayers("opA", [torch.conv2d]) OperationsSetToLayers("opB", [torch.tanh]) OperationsSetToLayers("opC", [LayerFilterParams(torch.relu, Greater("max_value", 7), negative_slope=0)]) - fusings = hm_keras.get_fusing_patterns() + fusings = hm_torch.get_fusing_patterns() self.assertEqual(len(fusings), 2) - p0, p1 = fusings[0], fusings[1] + p0, p1 = fusings[0].get(FUSED_LAYER_PATTERN), fusings[1].get(FUSED_LAYER_PATTERN) self.assertEqual(len(p0), 3) self.assertEqual(p0[0], torch.conv2d) diff --git a/tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py b/tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py index 725b111e7..ce5e46039 100644 --- a/tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py +++ b/tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py @@ -20,7 +20,7 @@ import pytest from model_compression_toolkit.core.common.graph.base_graph import OutTensor -from model_compression_toolkit.constants import FLOAT_BITWIDTH +from model_compression_toolkit.constants import FLOAT_BITWIDTH, FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG from model_compression_toolkit.core import ResourceUtilization from model_compression_toolkit.core.common import Graph from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo @@ -555,7 +555,8 @@ def test_compute_cuts_random_fusion_valid_utilization(self, seed, disable_quanti if i + fuse_len <= num_nodes: fused = tuple(nodes[j] for j in range(i, i + fuse_len)) fused_name = f"FusedNode_{'_'.join(n.name for n in fused)}" - fused_patterns.append([n.layer_class for n in fused]) + fused_pattern = {FUSED_LAYER_PATTERN: [n.layer_class for n in fused], FUSED_OP_QUANT_CONFIG: None} + fused_patterns.append(fused_pattern) fused_data[fused_name] = fused i += fuse_len else: diff --git a/tests_pytest/common_tests/unit_tests/core/test_fusion_info.py b/tests_pytest/common_tests/unit_tests/core/test_fusion_info.py index 2cf2d237a..e2ae32e11 100644 --- a/tests_pytest/common_tests/unit_tests/core/test_fusion_info.py +++ b/tests_pytest/common_tests/unit_tests/core/test_fusion_info.py @@ -19,6 +19,15 @@ from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator, FUSED_OP_ID_PREFIX, FusingInfo from model_compression_toolkit.target_platform_capabilities import FrameworkQuantizationCapabilities from model_compression_toolkit.core.common import BaseNode +from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG +from mct_quantizers import QuantizationMethod + +from tests.common_tests.helpers.generate_test_tpc import generate_test_attr_configs, generate_test_op_qc + +# Setup TEST_QC and TEST_QCO for testing. +TEST_QC_1 = generate_test_op_qc(**generate_test_attr_configs(default_cfg_nbits=8, default_cfg_quantizatiom_method=QuantizationMethod.POWER_OF_TWO)) +TEST_QC_2 = generate_test_op_qc(**generate_test_attr_configs(default_cfg_nbits=4, default_cfg_quantizatiom_method=QuantizationMethod.LUT_POT_QUANTIZER)) +TEST_QC_3 = generate_test_op_qc(**generate_test_attr_configs(default_cfg_nbits=2, default_cfg_quantizatiom_method=QuantizationMethod.LUT_SYM_QUANTIZER)) class MockBaseNode: @@ -42,7 +51,8 @@ def fusing_patterns(): """ - Returns predefined fusing patterns: Conv2D + ReLU and Linear + Softmax. """ - return [["Conv2d", "ReLU"], ["Linear", "Softmax"]] + return [{FUSED_LAYER_PATTERN: ["Conv2d", "ReLU"], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: ["Linear", "Softmax"], FUSED_OP_QUANT_CONFIG: None}] @pytest.fixture @@ -219,3 +229,198 @@ def test_is_node_in_fused_op_returns_false_for_absent_node(mock_graph, fusing_in unrelated = MockBaseNode("unrelated") assert not fi.is_node_in_fused_op(unrelated) + +def create_mock_base_node(name: str, layer_class: str): + """ + Function for creating the mock nodes required for a simple neural network structure. + Enables node name, layer class, type, and type checking method. + """ + + dummy_initalize = {'framework_attr': {}, + 'input_shape': (), + 'output_shape': (), + 'weights': {}} + + real_node = BaseNode(name=name, layer_class=layer_class, **dummy_initalize) + + node = Mock(spec=real_node) + node.is_match_type = real_node.is_match_type + node.layer_class = layer_class + node.name = name + + return node + +@pytest.fixture +def fusing_patterns_with_qconfig(): + """ + - Returns predefined fusing patterns: Conv2D + ReLU and Conv2D + Tanh, Linear + Softmax. + """ + return [{FUSED_LAYER_PATTERN: ["Conv2d", "ReLU"], FUSED_OP_QUANT_CONFIG: TEST_QC_1}, + {FUSED_LAYER_PATTERN: ["Conv2d", "Tanh"], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: ["Conv2d", "BatchNorm2d", "ReLU6"], FUSED_OP_QUANT_CONFIG: TEST_QC_2}, + {FUSED_LAYER_PATTERN: ["Linear", "Softmax"], FUSED_OP_QUANT_CONFIG: TEST_QC_3 }] + +@pytest.fixture +def fusing_info_generator_with_qconfig(fusing_patterns_with_qconfig): + """ + Creates a FusingInfoGenerator using the fusing patterns. + """ + return FusingInfoGenerator(fusing_patterns_with_qconfig) + +@pytest.fixture +def mock_qconfig_set_nodes(): + """ + Creates mock nodes representing a simple neural network structure. + - Nodes: Conv2D, ReLU, Conv2D, Tanh, Linear, Softmax. + """ + mock_node_list = [] + mock_node_list.append(create_mock_base_node(name='conv_1', layer_class='Conv2d')) + mock_node_list.append(create_mock_base_node(name='relu_1', layer_class='ReLU')) + mock_node_list.append(create_mock_base_node(name='conv_2', layer_class='Conv2d')) + mock_node_list.append(create_mock_base_node(name='relu_2', layer_class='ReLU')) + mock_node_list.append(create_mock_base_node(name='conv_3', layer_class='Conv2d')) + mock_node_list.append(create_mock_base_node(name='tanh', layer_class='Tanh')) + mock_node_list.append(create_mock_base_node(name='conv_4', layer_class='Conv2d')) + mock_node_list.append(create_mock_base_node(name='bn', layer_class='BatchNorm2d')) + mock_node_list.append(create_mock_base_node(name='relu6', layer_class='ReLU6')) + mock_node_list.append(create_mock_base_node(name='linear', layer_class='Linear')) + mock_node_list.append(create_mock_base_node(name='softmax', layer_class='Softmax')) + + return mock_node_list + +@pytest.fixture +def mock_qconfig_set_graph(mock_qconfig_set_nodes): + """ + Creates a mock graph with topologically sorted nodes and defined connectivity. + - Implements `get_next_nodes` and `get_prev_nodes` to maintain linear order. + """ + mock_nodes = mock_qconfig_set_nodes + + graph = Mock() + graph.nodes.return_value = mock_nodes + graph.get_topo_sorted_nodes.return_value = mock_nodes + + adjacency = { + mock_nodes[0]: [mock_nodes[1]], # conv_1 -> relu_1 + mock_nodes[1]: [mock_nodes[2]], # relu_1 -> conv_2 + mock_nodes[2]: [mock_nodes[3]], # conv_2 -> relu_2 + mock_nodes[3]: [mock_nodes[4]], # relu_2 -> conv_3 + mock_nodes[4]: [mock_nodes[5]], # conv_3 -> tanh + mock_nodes[5]: [mock_nodes[6]], # tanh -> conv_4 + mock_nodes[6]: [mock_nodes[7]], # conv_4 -> bn + mock_nodes[7]: [mock_nodes[8]], # bn -> relu6 + mock_nodes[8]: [mock_nodes[9]], # relu6 -> linear + mock_nodes[9]: [mock_nodes[10]], # linear -> softmax + mock_nodes[10]: [] # softmax has no outputs + } + + reverse_adjacency = { + mock_nodes[0]: [], # conv_1 has no inputs + mock_nodes[1]: [mock_nodes[0]], # relu_1 <- conv_1 + mock_nodes[2]: [mock_nodes[1]], # conv_2 <- relu_1 + mock_nodes[3]: [mock_nodes[2]], # relu_2 <- conv_2 + mock_nodes[4]: [mock_nodes[3]], # conv_3 <- relu_2 + mock_nodes[5]: [mock_nodes[4]], # tanh <- conv_3 + mock_nodes[6]: [mock_nodes[5]], # conv_4 <- tanh + mock_nodes[7]: [mock_nodes[6]], # bn <- conv_4 + mock_nodes[8]: [mock_nodes[7]], # relu6 <- bn + mock_nodes[9]: [mock_nodes[8]], # linear <- relu6 + mock_nodes[10]: [mock_nodes[9]] # softmax <- linear + } + + graph.get_next_nodes.side_effect = lambda node: adjacency.get(node, []) + graph.get_prev_nodes.side_effect = lambda node: reverse_adjacency.get(node, []) + + return graph + + +def test_fusing_info_qconfig_mapping(mock_qconfig_set_graph, fusing_info_generator_with_qconfig): + """ + Tests that each node is correctly mapped to its fused quantization configs. + """ + fi = fusing_info_generator_with_qconfig.generate_fusing_info(mock_qconfig_set_graph) + fi_qconfig_map = fi.get_fusing_quantization_config_map() + + expected_op1_id = f"{FUSED_OP_ID_PREFIX}conv_1_relu_1" + expected_op2_id = f"{FUSED_OP_ID_PREFIX}conv_2_relu_2" + expected_op3_id = f"{FUSED_OP_ID_PREFIX}conv_3_tanh" + expected_op4_id = f"{FUSED_OP_ID_PREFIX}conv_4_bn_relu6" + expected_op5_id = f"{FUSED_OP_ID_PREFIX}linear_softmax" + + assert len(fi_qconfig_map) == 5 + assert fi_qconfig_map[expected_op1_id] == TEST_QC_1 + assert fi_qconfig_map[expected_op2_id] == TEST_QC_1 + assert fi_qconfig_map[expected_op3_id] == None + assert fi_qconfig_map[expected_op4_id] == TEST_QC_2 + assert fi_qconfig_map[expected_op5_id] == TEST_QC_3 + + +def test_add_fused_operation_adds_data_and_qconfig(mock_qconfig_set_graph, fusing_info_generator_with_qconfig): + """ + Tests whether the added node is correctly assigned the fused quantization config. + """ + + fi = fusing_info_generator_with_qconfig.generate_fusing_info(mock_qconfig_set_graph) + fi_qconfig_map = fi.get_fusing_quantization_config_map() + + ### Checking the number of mappings before addition + assert len(fi_qconfig_map) == 5 + + node1 = create_mock_base_node(name='conv_a', layer_class='Conv2d') + node2 = create_mock_base_node(name='relu_b', layer_class='ReLU') + + op_id = f"{FUSED_OP_ID_PREFIX}conv_a_relu_b" + fi.add_fused_operation(op_id, (node1, node2)) + fi_qconfig_map = fi.get_fusing_quantization_config_map() + + ### Checking the mapping information after addition + assert op_id in fi.get_all_fused_operations() + assert fi.get_fused_node_name("conv_a") == op_id + assert fi.get_fused_node_name("relu_b") == op_id + + assert len(fi_qconfig_map) == 6 + assert fi.get_fused_op_quantization_config(op_id) == TEST_QC_1 + + +def test_remove_fusing_data_and_qconfig(mock_qconfig_set_graph, fusing_info_generator_with_qconfig, mock_qconfig_set_nodes): + """ + Tests that the fused quantization config for the specified operation is removed from the map. + """ + + fi = fusing_info_generator_with_qconfig.generate_fusing_info(mock_qconfig_set_graph) + + conv_1_node, relu_1_node = mock_qconfig_set_nodes[0], mock_qconfig_set_nodes[1] ### Conv2D(conv_1) + ReLU(relu_1) pattern. + conv_2_node, relu_2_node = mock_qconfig_set_nodes[2], mock_qconfig_set_nodes[3] ### Conv2D(conv_2) + ReLU(relu_2) pattern. targeted for deletion. + conv_3_node, tanh_node = mock_qconfig_set_nodes[4], mock_qconfig_set_nodes[5] ### Conv2D(conv_3) + Tanh(tanh) pattern. + conv_4_node, bn_node, relu6_node = mock_qconfig_set_nodes[6], mock_qconfig_set_nodes[7], mock_qconfig_set_nodes[8] ### Conv2D(conv_4) + BatchNorm2d(bn) + ReLU6(relu6) pattern. + linear, softmax_node = mock_qconfig_set_nodes[9], mock_qconfig_set_nodes[10] ### Linear(linear) + Softmax(softmax) pattern. + + op1_id = f"{FUSED_OP_ID_PREFIX}conv_1_relu_1" + op2_id = f"{FUSED_OP_ID_PREFIX}conv_2_relu_2" ### Conv2D(conv_2) + ReLU(relu_2) pattern. targeted for deletion. + op3_id = f"{FUSED_OP_ID_PREFIX}conv_3_tanh" + op4_id = f"{FUSED_OP_ID_PREFIX}conv_4_bn_relu6" + op5_id = f"{FUSED_OP_ID_PREFIX}linear_softmax" + + ### Checking the mapping information before deletion. + assert len(fi.get_fusing_quantization_config_map()) == 5 + assert fi.get_fused_op_quantization_config(op2_id) == TEST_QC_1 + assert fi.get_fused_nodes(op2_id) == (conv_2_node, relu_2_node) + + fi.remove_fused_operation(op2_id) + + ### Checking the mapping information after deletion. + assert len(fi.get_fusing_quantization_config_map()) == 4 + assert fi.get_fused_op_quantization_config(op2_id) == None + assert fi.get_fused_nodes(op2_id) == None + + ### Checking that the mapping information for the other operations remains unchanged. + ### Confirm that the mapping information with the same structure as the deleted one still exists and has not been removed. + assert fi.get_fused_op_quantization_config(op1_id) == TEST_QC_1 + assert fi.get_fused_nodes(op1_id) == (conv_1_node, relu_1_node) + + assert fi.get_fused_op_quantization_config(op3_id) == None + assert fi.get_fused_nodes(op3_id) == (conv_3_node, tanh_node) + assert fi.get_fused_op_quantization_config(op4_id) == TEST_QC_2 + assert fi.get_fused_nodes(op4_id) == (conv_4_node, bn_node, relu6_node) + assert fi.get_fused_op_quantization_config(op5_id) == TEST_QC_3 + assert fi.get_fused_nodes(op5_id) == (linear, softmax_node) \ No newline at end of file diff --git a/tests_pytest/keras_tests/integration_tests/core/fusion/test_fusing_info_generator_keras.py b/tests_pytest/keras_tests/integration_tests/core/fusion/test_fusing_info_generator_keras.py index 9c9706fb7..ce5f4319e 100644 --- a/tests_pytest/keras_tests/integration_tests/core/fusion/test_fusing_info_generator_keras.py +++ b/tests_pytest/keras_tests/integration_tests/core/fusion/test_fusing_info_generator_keras.py @@ -23,6 +23,7 @@ from tests_pytest._test_util.graph_builder_utils import build_node from tests_pytest.keras_tests.keras_test_util.keras_test_mixin import KerasFwMixin import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema +from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG from tensorflow.keras import backend as K @@ -52,8 +53,12 @@ class TestFusingConvRelu(BaseTestFusingInfoGeneratorKeras): schema.OperatorsSet(name=schema.OperatorSetNames.RELU))) ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv1_conv2_collapsed_relu": ( build_node(name="conv1_conv2_collapsed"), @@ -91,8 +96,12 @@ class TestFusingAnyActKeras(BaseTestFusingInfoGeneratorKeras): schema.OperatorsSet(name="AnyAct"))) ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv1_conv2_collapsed_tanh": (build_node(name="conv1_conv2_collapsed"), @@ -144,8 +153,12 @@ class TestFusingConvReLUOnlyKeras(BaseTestFusingInfoGeneratorKeras): schema.OperatorsSet(name="AnyAct"))) ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv1_conv2_collapsed_tanh": (build_node(name="conv1_conv2_collapsed"), @@ -207,8 +220,17 @@ class TestFusingComplexPatternsKeras(BaseTestFusingInfoGeneratorKeras): schema.OperatorsSet(name=schema.OperatorSetNames.ADD))), ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: [fusing_patterns[1]], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: [fusing_patterns[2]], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: [fusing_patterns[3]], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: [fusing_patterns[4]], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: [fusing_patterns[5]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv1_swish1_add": (build_node(name="conv1"), @@ -291,8 +313,12 @@ class TestFusingConvSwishWithMultiSuccessorsKeras(BaseTestFusingInfoGeneratorKer schema.OperatorsSet(name=schema.OperatorSetNames.SWISH))) ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv1_swish": ( build_node(name="conv1"), @@ -334,8 +360,12 @@ class TestFusingConvReluWithMultiPredecessorsKeras(BaseTestFusingInfoGeneratorKe schema.OperatorsSet(name=schema.OperatorSetNames.RELU))) ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv3_relu": ( build_node(name="conv3"), diff --git a/tests_pytest/pytorch_tests/integration_tests/core/fusion/test_fusing_info_generator_torch.py b/tests_pytest/pytorch_tests/integration_tests/core/fusion/test_fusing_info_generator_torch.py index e89b9e91a..798ef7206 100644 --- a/tests_pytest/pytorch_tests/integration_tests/core/fusion/test_fusing_info_generator_torch.py +++ b/tests_pytest/pytorch_tests/integration_tests/core/fusion/test_fusing_info_generator_torch.py @@ -26,7 +26,7 @@ import torch.nn as nn import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema - +from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG class BaseTestFusingInfoGeneratorPytorch(BaseFusingInfoGeneratorTest, TorchFwMixin): @@ -42,7 +42,6 @@ def _get_qc(self): - class TestFusingConvRelu(BaseTestFusingInfoGeneratorPytorch): last_node_activation_nbits, qcs = random_activation_configs() @@ -53,8 +52,12 @@ class TestFusingConvRelu(BaseTestFusingInfoGeneratorPytorch): schema.OperatorsSet(name=schema.OperatorSetNames.RELU))) ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv1_conv2_collapsed_relu": ( build_node(name="conv1_conv2_collapsed"), @@ -100,8 +103,12 @@ class TestFusingAnyAct(BaseTestFusingInfoGeneratorPytorch): schema.OperatorsSet(name="AnyAct"))) ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv1_conv2_collapsed_tanh": (build_node(name="conv1_conv2_collapsed"), @@ -167,8 +174,12 @@ class TestFusingConvReLUOnly(BaseTestFusingInfoGeneratorPytorch): schema.OperatorsSet(name="AnyAct"))) ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv1_conv2_collapsed_tanh": (build_node(name="conv1_conv2_collapsed"), build_node(name="tanh", qcs=qcs)), @@ -242,8 +253,17 @@ class TestFusingComplexPatterns(BaseTestFusingInfoGeneratorPytorch): schema.OperatorsSet(name=schema.OperatorSetNames.ADD))) ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: [fusing_patterns[1]], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: [fusing_patterns[2]], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: [fusing_patterns[3]], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: [fusing_patterns[4]], FUSED_OP_QUANT_CONFIG: None}, + {FUSED_LAYER_PATTERN: [fusing_patterns[5]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv1_swish_add": ( build_node(name="conv1"), @@ -341,8 +361,12 @@ class TestFusingConvSwishWithMultiSuccessors(BaseTestFusingInfoGeneratorPytorch) schema.OperatorsSet(name=schema.OperatorSetNames.SWISH))) ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv1_swish": ( build_node(name="conv1"), @@ -390,8 +414,12 @@ class TestFusingConvReluWithMultiPredecessors(BaseTestFusingInfoGeneratorPytorch schema.OperatorsSet(name=schema.OperatorSetNames.RELU))) ] + expected_fusing_patterns = [ + {FUSED_LAYER_PATTERN: [fusing_patterns[0]], FUSED_OP_QUANT_CONFIG: None} + ] + expected_fi = FusingInfo( - fusing_patterns=fusing_patterns, + fusing_patterns=expected_fusing_patterns, fusing_data={ "FusedNode_conv3_relu": ( build_node(name="conv3"),