diff --git a/.github/workflows/run_keras_tests.yml b/.github/workflows/run_keras_tests.yml index 2f194c8c1..eb41d46af 100644 --- a/.github/workflows/run_keras_tests.yml +++ b/.github/workflows/run_keras_tests.yml @@ -27,11 +27,10 @@ jobs: pip install tensorflow==${{ inputs.tf-version }} sony-custom-layers pip install pytest pytest-mock pip check - - name: Run unittests - run: | - python -m unittest discover tests/keras_tests -v - - name: Run pytest run: | pytest tests_pytest/keras_tests + - name: Run unittests + run: | + python -m unittest discover tests/keras_tests -v diff --git a/.github/workflows/run_pytorch_tests.yml b/.github/workflows/run_pytorch_tests.yml index e3e450bff..0e1d470d3 100644 --- a/.github/workflows/run_pytorch_tests.yml +++ b/.github/workflows/run_pytorch_tests.yml @@ -33,10 +33,11 @@ jobs: pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime "onnxruntime-extensions<0.14" pip install pytest pytest-mock pip check - - name: Run unittests - run: | - python -m unittest discover tests/pytorch_tests -v - name: Run pytest run: | pytest tests_pytest/pytorch_tests + - name: Run unittests + run: | + python -m unittest discover tests/pytorch_tests -v + diff --git a/model_compression_toolkit/core/common/fusion/fusing_info.py b/model_compression_toolkit/core/common/fusion/fusing_info.py new file mode 100644 index 000000000..f7241cf77 --- /dev/null +++ b/model_compression_toolkit/core/common/fusion/fusing_info.py @@ -0,0 +1,374 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from model_compression_toolkit.target_platform_capabilities import LayerFilterParams +from dataclasses import dataclass, field + +from typing import Optional, List, Dict, Any, Tuple +import copy + +# The prefix of each fused operator (the suffix is a combination of the +# nodes names that combine the fused operator). +FUSED_OP_ID_PREFIX = "FusedNode_" + + +@dataclass +class FusingInfo: + """ + This class manages information about fused operations in a graph. + + The key responsibility of this class is maintaining a mapping between original nodes + and their corresponding fused operation IDs. This mapping helps track which nodes + belong to fused operations and validate this info is correct after changes in the graph. + + The core structures maintained are: + - `fusing_data`: A dictionary mapping fused operation IDs to lists of nodes that belong to that operation. + - `node_to_fused_node_map`: A dictionary mapping each node name to the ID of the fused operation it belongs to. 
+ + """ + fusing_patterns: any = None + fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict) + node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict) + + def __post_init__(self): + """Validates and initializes mappings after dataclass instantiation.""" + for op_id, op_nodes in self.fusing_data.items(): + assert isinstance(op_id, str) and op_id.startswith(FUSED_OP_ID_PREFIX), f"Found invalid fused op id: {op_id}" + assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}" + + self._init_node_mapping() + + def _init_node_mapping(self) -> None: + """ + Init the node-to-fused-node mapping based on the initial fusing data. + """ + self.node_to_fused_node_map.clear() + for op_id, nodes in self.fusing_data.items(): + for node in nodes: + self.node_to_fused_node_map[node.name] = op_id + + def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None: + """ + Add a new fused operation with the given ID and set of nodes. + + Args: + op_id (str): The identifier for the fused operation. + nodes (Tuple[BaseNode]): The tuple of nodes that form the fused operation. + + Raises: + ValueError: If the operation ID already exists. + """ + if op_id in self.fusing_data: + raise ValueError(f"Fused operation {op_id} already exists.") + assert isinstance(nodes, tuple), f"Expected nodes to be a tuple but its type is {type(nodes)}" + self.fusing_data[op_id] = nodes + # Update the mapping for these nodes + for node in nodes: + self.node_to_fused_node_map[node.name] = op_id + + def remove_fused_operation(self, op_id: str) -> None: + """ + Remove a fused operation by its ID. + + Args: + op_id (str): The identifier for the fused operation to remove. + + Raises: + ValueError: If the operation ID does not exist. 
+ """ + if op_id not in self.fusing_data: + raise ValueError(f"Fused operation {op_id} does not exist.") + # Remove nodes from the mapping + nodes = self.fusing_data[op_id] + for node in nodes: + self.node_to_fused_node_map.pop(node.name, None) + del self.fusing_data[op_id] + + def get_fused_node_name(self, node_name: str) -> Optional[str]: + """ + Get the name of the fused node containing the given original node name. + + Args: + node_name: The name of a node from the original graph. + + Returns: + The name of the fused node containing this node, or None if not fused. + """ + return self.node_to_fused_node_map.get(node_name) + + def get_node_to_fused_node_map(self) -> Dict[str, str]: + """ + Retrieve a copy of the mapping from original node names to fused node names. + + Returns: + A dictionary mapping each original node name to its fused node name. + """ + return self.node_to_fused_node_map.copy() + + def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]: + """ + Retrieve the list of nodes for a given fused operation ID. + + Args: + op_id (str): The identifier for the fused operation. + + Returns: + Optional[List[BaseNode]]: The list of nodes for the operation, or None if not found. + """ + return self.fusing_data.get(op_id) + + def is_node_in_fused_op(self, node: 'BaseNode') -> bool: + """ + Check if a node is part of any fused operation. + + Args: + node (BaseNode): The node to check. + + Returns: + bool: True if the node is in any fused operation, False otherwise. + """ + return any(node in nodes for nodes in self.fusing_data.values()) + + def get_all_fused_operations(self) -> Dict[str, Tuple['BaseNode']]: + """ + Retrieve fused information. + + Returns: + Dict[str, List[BaseNode]]: The fusing data. + """ + return self.fusing_data + + + @staticmethod + def generate_fused_op_id(nodes: List['BaseNode']) -> str: + """ + Generates an identifier for a fused operation by concatenating + the names of the given nodes with a prefix. 
+ + Args: + nodes (List[BaseNode]): A list of nodes to be fused. + + Returns: + str: An identifier string for the fused operation. + """ + id = FUSED_OP_ID_PREFIX + '_'.join([node.name for node in nodes]) + return id + + def validate(self, graph) -> None: + """ + Validate that the fusing information is consistent with the given graph and generation logic. + + This method performs the following checks: + 1. All nodes in the fusing data exist in the graph. + 2. Each fused sequence forms a valid linear chain in the graph: + - Each node (except the last) has exactly one successor, which is the next node in the sequence. + 3. No node is part of more than one fused operation. + 4. Each fused sequence matches a valid fusing pattern from the original set. + + Args: + graph: The computational graph to validate against. It is expected to have: + - `get_topo_sorted_nodes()`: Returns a list of nodes in topological order. + - `get_next_nodes(node)`: Returns a list of direct successor nodes. + + Raises: + ValueError: If any validation check fails. 
+ """ + graph_nodes = set(graph.get_topo_sorted_nodes()) # Retrieve all nodes from the graph + all_fused_nodes = set() # Track all nodes used in fusions to ensure no overlap + + for op_id, nodes in self.fusing_data.items(): + # Check 1: Ensure all fused nodes exist in the graph + for node in nodes: + if node not in graph_nodes: + raise ValueError(f"Fused operation {op_id} contains node {node.name} not present in the graph.") + + # Check 2: Validate the fusion sequence forms a valid linear chain + for i in range(len(nodes) - 1): # Up to the second-to-last node + current_node = nodes[i] + next_node = nodes[i + 1] + successors = graph.get_next_nodes(current_node) + if len(successors) != 1 or successors[0] != next_node: + raise ValueError( + f"Fused operation {op_id} is not a valid linear chain: " + f"node {current_node.name} does not connect directly to {next_node.name} " + f"with exactly one successor (found successors: {[n.name for n in successors]})." + ) + + # Check 3: Ensure no node is reused across fusions + node_set = set(nodes) + overlap = node_set & all_fused_nodes + if overlap: + raise ValueError( + f"Fused operation {op_id} contains nodes already used in another fusion: " + f"{[node.name for node in overlap]}." + ) + all_fused_nodes.update(node_set) + + # Check 4: Ensure the sequence matches a valid fusing pattern + if not is_valid_fusion(self.fusing_patterns, nodes): + raise ValueError( + f"Fused operation {op_id} does not match any valid fusing pattern " + f"from {self.fusing_patterns}." + ) + + def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool: + """ + Check whether the given nodes are eligible to be fused based on predefined fusing patterns. + + This method retrieves the fusing patterns from `self.fqc` and verifies whether the + given sequence of nodes matches any of the valid patterns. + + Args: + nodes (List[BaseNode]): The list of nodes to check for fusion eligibility. 
+ + Returns: + bool: True if the nodes can be fused according to fusing patterns, otherwise False. + """ + # If no fusing patterns are defined, fusion is not possible + if not self.fusing_patterns: + return False + + # Check if the provided nodes match a valid fusion pattern + return is_valid_fusion(fusing_patterns=self.fusing_patterns, nodes=nodes) + + def __repr__(self) -> str: + """ + Return a string representation of the fusing information. + """ + fusing_data_repr = "\n".join( + f" {op_id}: [{', '.join(node.name for node in nodes)}]" + for op_id, nodes in self.fusing_data.items() + ) + mapping_repr = ", ".join( + f"{node} -> {op_id}" for node, op_id in self.node_to_fused_node_map.items() + ) + return ( + f"FusingInfo(\n" + f" Total fused operations: {len(self.fusing_data)}\n" + f" Fusing Data:\n{fusing_data_repr}\n" + f" Node-to-Fused Mapping:\n {mapping_repr}\n" + f")" + ) + + +class FusingInfoGenerator: + def __init__(self, fusing_patterns): + self._fusing_patterns = fusing_patterns + + def generate_fusing_info(self, graph) -> FusingInfo: + """ + Generate fusing information based on the graph and fusing patterns. + + Args: + graph: The input graph to analyze, expected to have methods like + get_topo_sorted_nodes() and get_next_nodes(node). + + Returns: + A dictionary where keys are unique fusion identifiers (e.g., 'fused_op_0') + and values are lists of BaseNode objects representing nodes in that fusion. + + Notes: + - Assumes get_valid_fusing_patterns_for_node and is_valid_fusion functions are defined elsewhere. + - Nodes are processed in topological order to respect operation sequence. + - Fusions are linear sequences (each node has exactly one successor). + - Each node belongs to at most one fused operation. 
+ """ + if not self._fusing_patterns: + return FusingInfo(fusing_patterns=self._fusing_patterns) + + # Find max fusion + max_layers_fusing = 0 if len(self._fusing_patterns) == 0 else max([len(fusing_pattern) for fusing_pattern in self._fusing_patterns]) + + # Travel along the graph to find layers for fusing + nodes = graph.get_topo_sorted_nodes() + + fusing_info: Dict[str, Tuple['BaseNode']] = {} + fused_nodes = [] # nodes that are participating in fusing + + for node in nodes: + # Skip if already in fusing + if node in fused_nodes: + continue + # Start fusing search + fusing_nodes = [] # nodes that are candidates for participating in fusing + patterns = copy.deepcopy(self._fusing_patterns) + next_nodes = [node] + for i in range(max_layers_fusing): + patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i) + if len(patterns) == 0: # Give up if no more fusion pattern + break + fusing_nodes.append(next_nodes[0]) + next_nodes = graph.get_next_nodes(fusing_nodes[-1]) + if len(next_nodes) != 1: # Give up if node has more than one connection (not supported for fusion) + break + + # New fusion + if is_valid_fusion(self._fusing_patterns, fusing_nodes): + fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes) + assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}" + fusing_info[fused_op_id] = tuple(fusing_nodes) + fused_nodes.extend(fusing_nodes) + + return FusingInfo(fusing_data=fusing_info, fusing_patterns=self._fusing_patterns) + + +def get_valid_fusing_patterns_for_node(fusing_patterns: List[List[Any]], + node: 'BaseNode', + idx: int = 0) -> List[List[Any]]: + """ + Returns only the fusing patterns where a specific layer (at index idx) matches the given node — either by type or filter params. 
+ + Args: + fusing_patterns: supported fusings + node: node to decide if it can be a part of fusion + idx: index of layer in the fusion + + Returns: + fusing_patterns after filtering non-relevant fusions + """ + valid_fusing_patterns = [] + for i, fusing_pattern in enumerate(fusing_patterns): + if idx < len(fusing_pattern): + if ((type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params( + fusing_pattern[idx])) or node.is_match_type(fusing_pattern[idx])): + valid_fusing_patterns.append(fusing_pattern) + + # Return only valid patterns for this node + return valid_fusing_patterns + + +def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -> bool: + """ + Check if the fusion is valid: exist in fusing_patterns + Args: + fusing_patterns: supported fusing patterns + nodes: nodes which are participating in fusion + Returns: + whether the fusion in valid + """ + fusion_depth = len(nodes) + if fusion_depth <= 1: + return False + for fusing_pattern in fusing_patterns: + if fusion_depth != len(fusing_pattern): + continue + counter = 0 + for i, layer in enumerate(fusing_pattern): + if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or \ + nodes[i].is_match_type(layer): + counter += 1 + if counter == fusion_depth: + return True + return False diff --git a/model_compression_toolkit/core/common/fusion/graph_fuser.py b/model_compression_toolkit/core/common/fusion/graph_fuser.py index fe6dcb007..ecc8fb4ef 100644 --- a/model_compression_toolkit/core/common/fusion/graph_fuser.py +++ b/model_compression_toolkit/core/common/fusion/graph_fuser.py @@ -1,4 +1,4 @@ -# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,13 @@ # limitations under the License. 
# ============================================================================== -from typing import Dict, List +import copy +from typing import List, Tuple -from model_compression_toolkit.core.common import Graph, BaseNode -from model_compression_toolkit.core.common.graph.base_graph import OutTensor +from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator +from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor +from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig +from itertools import product class FusedLayerType: @@ -27,35 +30,41 @@ class FusedLayerType: def __init__(self): self.__name__ = 'FusedLayer' - class GraphFuser: - - def create_fused_graph(self, graph: Graph) -> Dict[str, str]: + def apply_node_fusion(self, graph: Graph) -> Graph: """ - GraphFuser is responsible for fusing nodes in a networkx graph. - The fusion process involves: - 1. Creating new fused nodes to represent these groups. - 2. Updating the graph structure to replace the original nodes with fused nodes. - 3. Maintaining mapping of original node names to their fused node names. + Applies node fusion to the graph according the fusing_info it has. + + The fusion process includes: + 1. Generating new fused nodes to replace groups of original nodes. + 2. Updating the graph structure to replace those nodes with the fused representations. Args: - graph: Graph to fuse its nodes. + graph: The graph and its fusing metadata. Returns: - Mapping of original node names to their fused node names + The updated graph with fused nodes replacing the original node groups. 
""" - fused_nodes_mapping = {} - # Iterate through each group of nodes to be fused - for fused_nodes_list in graph.fused_nodes: - new_fused_node = self._create_fused_node(fused_nodes_list) - self._replace_nodes_with_fused_node(graph, fused_nodes_list, new_fused_node) - # Update the mapping to keep track of which original nodes are now part of which fused nodes - for node in fused_nodes_list: - fused_nodes_mapping[node.name] = new_fused_node.name - return fused_nodes_mapping + graph_copy = copy.deepcopy(graph) + expected_fusing_info = FusingInfoGenerator(graph_copy.fusing_info.fusing_patterns).generate_fusing_info(graph_copy) + + if expected_fusing_info != graph_copy.fusing_info: + raise ValueError( + f"Mismatch between expected and existing fusing information.\n" + f"Expected:\n{expected_fusing_info}\nExisting:\n{graph_copy.fusing_info}" + ) + + fused_operations = list(graph_copy.fusing_info.get_all_fused_operations().items()) + for fused_node_id, original_nodes in fused_operations: + fused_node = self._create_fused_node(fused_node_id, original_nodes) + graph_copy.fusing_info.remove_fused_operation(fused_node_id) + self._replace_nodes_with_fused_node(graph_copy, original_nodes, fused_node) + + return graph_copy + @staticmethod - def _create_fused_node(nodes: List[BaseNode]) -> BaseNode: + def _create_fused_node(fused_node_id: str, nodes: Tuple[BaseNode]) -> BaseNode: """ Create a new node that represents the fusion of the given nodes. @@ -67,22 +76,28 @@ def _create_fused_node(nodes: List[BaseNode]) -> BaseNode: """ # Create a new node with a name that reflects its components # Use the input shape of the first node and output shape of the last node - fused_node = BaseNode(name='FusedNode_' + '_'.join([node.name for node in nodes]), + # TODO: consider replacing the fused node with a sub-model to allow inference on it, etc. 
+ fused_node = BaseNode(name=fused_node_id, framework_attr={}, input_shape=nodes[0].input_shape, output_shape=nodes[-1].output_shape, weights={}, layer_class=FusedLayerType) - # Preserve the final activation quantization configuration - # This is important for maintaining the correct behavior of the fused node + activation_cfgs = [c.activation_quantization_cfg for c in nodes[-1].candidates_quantization_cfg] + fused_node.candidates_quantization_cfg = [ + CandidateNodeQuantizationConfig(weights_quantization_cfg=None, activation_quantization_cfg=a) for a in + activation_cfgs] + + # Keep the final configurations if they were set already. + fused_node.final_weights_quantization_cfg = nodes[0].final_weights_quantization_cfg fused_node.final_activation_quantization_cfg = nodes[-1].final_activation_quantization_cfg return fused_node @staticmethod def _replace_nodes_with_fused_node(graph: Graph, - nodes_to_fuse: List[BaseNode], + nodes_to_fuse: Tuple[BaseNode], fused_node: BaseNode): """ Replace the specified nodes in the graph with a new fused node. 
@@ -118,6 +133,11 @@ def _replace_nodes_with_fused_node(graph: Graph, for next_node in subsequent_nodes: assert next_node in nodes_to_fuse # Ensure we're not removing edges outside the fusion graph.remove_edge(current_node, next_node) + # next_node can have more incoming edges from other nodes that are not + # in the fusion and we should remove them to: + in_edges = graph.incoming_edges(next_node) + for ie in in_edges: + graph.remove_edge(ie.source_node, next_node) # Handle the case where fused nodes are part of the graph's outputs graph_output_tensors = graph.get_outputs() @@ -136,3 +156,5 @@ def _replace_nodes_with_fused_node(graph: Graph, # Finally, add the new fused node to the graph graph.add_node(fused_node) + + diff --git a/model_compression_toolkit/core/common/fusion/layer_fusing.py b/model_compression_toolkit/core/common/fusion/layer_fusing.py deleted file mode 100644 index 1f2981eb3..000000000 --- a/model_compression_toolkit/core/common/fusion/layer_fusing.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -import copy -from typing import Any, List -from model_compression_toolkit.core.common.graph.base_graph import Graph -from model_compression_toolkit.core.common.graph.base_node import BaseNode -from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \ - FrameworkQuantizationCapabilities -from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.layer_filter_params import LayerFilterParams - - -def filter_fusing_patterns(fusing_patterns: List[List[Any]], node: BaseNode, idx: int = 0) -> List[List[Any]]: - """ - Update relevant fusing patterns object if layer number 'idx' inside the fusion matches the node - Args: - fusing_patterns: supported fusings - node: node to decide if it can be a part of fusion - idx: index of layer in the fusion - Returns: - fusing_patterns after filtering non-relevant fusions - """ - valid_fusing_patterns = [] - for i, fusing_pattern in enumerate(fusing_patterns): - if idx < len(fusing_pattern): - if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or \ - node.is_match_type(fusing_pattern[idx]): - valid_fusing_patterns.append(fusing_pattern) - - # Return only valid patterns for this node - return valid_fusing_patterns - - -def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) -> bool: - """ - Check if the fusion is valid: exist in fusing_patterns - Args: - fusing_patterns: supported fusing patterns - nodes: nodes which are participating in fusion - Returns: - whether the fusion in valid - """ - fusion_depth = len(nodes) - if fusion_depth <= 1: - return False - for fusing_pattern in fusing_patterns: - if fusion_depth != len(fusing_pattern): - continue - counter = 0 - for i, layer in enumerate(fusing_pattern): - if (type(layer) == LayerFilterParams and 
nodes[i].is_match_filter_params(layer)) or \ - nodes[i].is_match_type(layer): - counter += 1 - if counter == fusion_depth: - return True - return False - - -def disable_nodes_activation_quantization(nodes: List[BaseNode]): - """ - Disable activation for non-quantization needed due to fusion - Args: - nodes: nodes to update their activation quantization - """ - for node in nodes: - for qc in node.candidates_quantization_cfg: - qc.activation_quantization_cfg.enable_activation_quantization = False - - -def fusion(graph: Graph, fqc: FrameworkQuantizationCapabilities) -> Graph: - """ - Fusing defines a list of operators that should be combined and treated as a single operator, - hence no quantization is applied between them when they appear in the graph. - This function search and disable quantization for such patterns. - Args: - graph: Graph we apply the fusion on. - fqc: FrameworkQuantizationCapabilities object that describes the desired inference target platform (includes fusing patterns MCT should handle). - Returns: - Graph after applying fusion activation marking. 
- """ - fusing_patterns = fqc.get_fusing_patterns() - if len(fusing_patterns) == 0: - return graph - - # Find max fusion - max_layers_fusing = 0 if len(fusing_patterns) == 0 else max([len(fusing_pattern) for fusing_pattern in fusing_patterns]) - - - # -------------------------------- # - # Fusion algorithm - # -------------------------------- # - fused_graph = copy.deepcopy(graph) - - # Travel along the graph to find layers for fusing - nodes = fused_graph.get_topo_sorted_nodes() - fused_nodes = [] # nodes that are participating in fusing - for node in nodes: - # Skip if already in fusing - if node in fused_nodes: - continue - # Start fusing search - fusing_nodes = [] # nodes that are candidates for participating in fusing - patterns = copy.deepcopy(fusing_patterns) - next_nodes = [node] - for i in range(max_layers_fusing): - patterns = filter_fusing_patterns(patterns, next_nodes[0], i) - if len(patterns) == 0: # Give up if no more fusion pattern - break - fusing_nodes.append(next_nodes[0]) - next_nodes = fused_graph.get_next_nodes(fusing_nodes[-1]) - if len(next_nodes) != 1: # Give up if node has more than one connection (not supported for fusion) - break - - # New fusion: mark all nodes in the fusion except last one - if is_valid_fusion(fusing_patterns, fusing_nodes): - fused_nodes.extend(fusing_nodes) - disable_nodes_activation_quantization(fusing_nodes[:-1]) - fused_graph.update_fused_nodes(fusing_nodes) - - return fused_graph diff --git a/model_compression_toolkit/core/common/graph/base_graph.py b/model_compression_toolkit/core/common/graph/base_graph.py index c94b7db91..79afb2fd7 100644 --- a/model_compression_toolkit/core/common/graph/base_graph.py +++ b/model_compression_toolkit/core/common/graph/base_graph.py @@ -15,7 +15,8 @@ from collections import namedtuple from copy import copy, deepcopy -from typing import List, Tuple, Any +from functools import wraps +from typing import List, Tuple, Any, Callable import networkx as nx import numpy as np @@ -23,6 
+24,7 @@ from networkx.algorithms.dag import topological_sort from model_compression_toolkit.core.common.framework_info import FrameworkInfo +from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX, EDGE_SOURCE_INDEX from model_compression_toolkit.core.common.graph.edge import Edge, convert_to_edge from model_compression_toolkit.core.common.graph.graph_searches import GraphSearches @@ -36,6 +38,27 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \ FrameworkQuantizationCapabilities + +def validate_graph_after_change(method: Callable) -> Callable: + """ + Decorator for graph-mutating methods. After the decorated method executes, + this decorator calls `self.validate()` to ensure the graph remains in a valid state. + + Args: + method: The graph-modifying method to wrap. + + Returns: + A wrapped method that validates the graph after execution. + """ + @wraps(method) + def wrapper(self, *args, **kwargs): + result = method(self, *args, **kwargs) + if not self.skip_validation_check: + self.validate() # calls Graph.validate(). Ensure graph consistency after changes. + return result + return wrapper + + OutTensor = namedtuple('OutTensor', 'node node_out_index') @@ -63,6 +86,11 @@ def __init__(self, """ super().__init__(**attr) + + # This must be set first to ensure it's available when validation runs during graph creation. 
+ self._skip_validation_check = False + self._fusing_info = FusingInfo() + self.name = name self.input_nodes = input_nodes self.output_nodes = output_nodes @@ -75,7 +103,25 @@ def __init__(self, **e.get_attributes()) self.user_info = UserInformation() self.fw_info = fw_info - self.fused_nodes = [] + + @property + def skip_validation_check(self) -> bool: + return self._skip_validation_check + + @skip_validation_check.setter + def skip_validation_check(self, value: bool): + if not isinstance(value, bool): + raise ValueError("skip_validation_check must be a boolean.") + self._skip_validation_check = value + + @property + def fusing_info(self) -> FusingInfo: + return self._fusing_info + + @fusing_info.setter + @validate_graph_after_change + def fusing_info(self, fusing_info: FusingInfo): + self._fusing_info = fusing_info def set_fw_info(self, fw_info: FrameworkInfo): @@ -139,6 +185,7 @@ def get_outputs(self) -> List[OutTensor]: return self.output_nodes + @validate_graph_after_change def set_inputs(self, input_nodes: List[BaseNode]): """ @@ -149,6 +196,7 @@ def set_inputs(self, self.input_nodes = input_nodes + @validate_graph_after_change def set_outputs(self, output_nodes: List[OutTensor]): """ @@ -321,6 +369,7 @@ def get_prev_nodes(self, sort_attr = None return [edges_list.source_node for edges_list in self.incoming_edges(node_obj, sort_by_attr=sort_attr)] + @validate_graph_after_change def reconnect_out_edges(self, current_node: BaseNode, new_node: BaseNode): @@ -337,6 +386,7 @@ def reconnect_out_edges(self, self.add_edge(new_node, oe.sink_node, **oe.get_attributes()) self.remove_edge(current_node, oe.sink_node) + @validate_graph_after_change def reconnect_in_edges(self, current_node: BaseNode, new_node: BaseNode): @@ -353,6 +403,7 @@ def reconnect_in_edges(self, self.add_edge(ie.source_node, new_node, **ie.get_attributes()) self.remove_edge(ie.source_node, current_node) + @validate_graph_after_change def add_node_with_in_edges(self, new_node: BaseNode, input_nodes: 
List[BaseNode], input_nodes_output_index: List[int] = []): """ @@ -378,6 +429,7 @@ def add_node_with_in_edges(self, new_node: BaseNode, input_nodes: List[BaseNode] for sink_index, (in_node, source_index) in enumerate(zip(input_nodes, input_nodes_output_index)): self.add_edge(in_node, new_node, source_index=source_index, sink_index=sink_index) + @validate_graph_after_change def replace_output_node(self, current_node: BaseNode, new_node: BaseNode): @@ -400,6 +452,7 @@ def replace_output_node(self, new_graph_outputs[graph_ot_index] = OutTensor(new_node, ot.node_out_index) self.set_outputs(new_graph_outputs) + @validate_graph_after_change def replace_input_node(self, current_node: BaseNode, new_node: BaseNode): @@ -424,6 +477,7 @@ def replace_input_node(self, new_graph_inputs.append(new_node) self.set_inputs(new_graph_inputs) + @validate_graph_after_change def remove_node(self, node_to_remove: BaseNode, new_graph_inputs: List[BaseNode] = None, @@ -713,16 +767,6 @@ def retrieve_preserved_quantization_node(self, node: BaseNode) -> BaseNode: node = prev_nodes[0] return node - def update_fused_nodes(self, fusion: List[Any]): - """ - Updates the graphs fusions list with a new list of nodes that have been fused. - - Args: - fusion: A list of nodes that have been fused. - - """ - self.fused_nodes.append(fusion) - def has_any_configurable_activation(self) -> bool: """ Checks whether any node in the graph has a configurable activation quantization. @@ -742,6 +786,7 @@ def has_any_configurable_weights(self): return any([n.has_any_configurable_weight() for n in self.nodes]) + @validate_graph_after_change def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode): """ Replaces a node in the graph with a new node. 
@@ -867,4 +912,36 @@ def _find_intermediate_and_exit_nodes(self, entry_node: BaseNode, fw_impl: Any) return intermediate_nodes, next_node + def disable_fused_nodes_activation_quantization(self): + """ + Disable activation quantization for all nodes in fused operations, + except for the last node in each fused group. + """ + nodes_to_disable = [node for nodes in self.fusing_info.get_all_fused_operations().values() for node in nodes[:-1]] + for node in nodes_to_disable: + for qc in node.candidates_quantization_cfg: + qc.activation_quantization_cfg.enable_activation_quantization = False + + def validate(self): + """ + Validate that the current state of the graph is consistent with + the fusing information (e.g., no missing or incorrect fused node mapping). + Returns: + The result of the FusingInfo validation logic (typically None or raises error). + """ + return self.fusing_info.validate(self) + + @validate_graph_after_change + def add_edge(self, *args, **kwargs): + """ + Wrap networkx functions (that modifies the graph) with our validate decorator. + """ + return super().add_edge(*args, **kwargs) + + @validate_graph_after_change + def remove_edge(self, *args, **kwargs): + """ + Wrap networkx functions (that modifies the graph) with our validate decorator. + """ + return super().remove_edge(*args, **kwargs) \ No newline at end of file diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py index 4189cc37a..dc40670eb 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py @@ -65,6 +65,7 @@ def search_bit_width(graph: Graph, bit-width index on the node). 
""" + assert target_resource_utilization.is_any_restricted() # If we only run weights compression with MP than no need to consider activation quantization when computing the @@ -88,6 +89,11 @@ def search_bit_width(graph: Graph, if search_method != BitWidthSearchMethod.INTEGER_PROGRAMMING: raise NotImplementedError() + # Validation is skipped during the mixed-precision search configuration because fusing information is not + # relevant for the virtual graph. Therefore, validation checks are disabled before the search begins and + # re-enabled once it completes. + graph.skip_validation_check = True + # Search manager and LP are highly coupled, so LP search method was moved inside search manager. search_manager = MixedPrecisionSearchManager(graph, fw_info, @@ -96,6 +102,8 @@ def search_bit_width(graph: Graph, target_resource_utilization) result_bit_cfg = search_manager.search() + graph.skip_validation_check = False + if mp_config.refine_mp_solution: result_bit_cfg = greedy_solution_refinement_procedure(result_bit_cfg, search_manager, target_resource_utilization) diff --git a/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py index 0fe5c7c94..a1833f24b 100644 --- a/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py @@ -71,11 +71,13 @@ def __init__(self, if weights_quantization_cfg is not None: self.weights_quantization_cfg = weights_quantization_cfg - else: - if any(v is None for v in (qc, op_cfg, node_attrs_list)): # pragma: no cover - Logger.critical("Missing required arguments to initialize a node weights quantization configuration. 
" - "Ensure QuantizationConfig, OpQuantizationConfig, weights quantization function, " - "parameters function, and weights attribute quantization config are provided.") - self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc, op_cfg=op_cfg, + elif all(v is not None for v in (qc, op_cfg, node_attrs_list)): + self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc, + op_cfg=op_cfg, weights_channels_axis=weights_channels_axis, node_attrs_list=node_attrs_list) + else: + self.weights_quantization_cfg = None + Logger.debug("Setting weights quantization config as None during CandidateNodeQuantizationConfig creation." + "Notice, this should happen only for FLN nodes.") + diff --git a/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py b/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py index 1f17263a4..087c9e0dc 100644 --- a/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +++ b/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py @@ -19,11 +19,11 @@ import numpy as np +from model_compression_toolkit.core.common import Graph from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig from model_compression_toolkit.core import common from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig from model_compression_toolkit.logger import Logger -from model_compression_toolkit.core.common.graph.base_graph import Graph from model_compression_toolkit.core.common.graph.base_node import BaseNode from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher from mct_quantizers import QuantizationMethod @@ -143,6 +143,21 @@ def substitute(self, AttributeQuantizationConfig( enable_weights_quantization=False))) + # Check if the source node was part of a fusion. 
If so, there are two cases: + # either this is no longer a fusion, and the fusion info should be updated by removing + # the current info, or this creates a new fusion and the old pattern should be + # replaced with the new one. + fi = graph.fusing_info + fused_op = fi.get_fused_node_name(source_node.name) + if fused_op: + fused_nodes = list(fi.get_fused_nodes(fused_op)) + assert source_node in fused_nodes + fused_nodes.insert(fused_nodes.index(source_node)+1, bn_node) + fi.remove_fused_operation(fused_op) + if fi.is_nodes_eligible_to_be_fused(fused_nodes): + op_id = fi.generate_fused_op_id(fused_nodes) + fi.add_fused_operation(op_id, tuple(fused_nodes)) + graph.reconnect_out_edges(current_node=source_node, new_node=bn_node) graph.replace_output_node(current_node=source_node, new_node=bn_node) graph.add_node_with_in_edges(bn_node, [source_node]) diff --git a/model_compression_toolkit/core/graph_prep_runner.py b/model_compression_toolkit/core/graph_prep_runner.py index 78d543f15..15006d2da 100644 --- a/model_compression_toolkit/core/graph_prep_runner.py +++ b/model_compression_toolkit/core/graph_prep_runner.py @@ -18,7 +18,7 @@ from model_compression_toolkit.core.common import FrameworkInfo from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation -from model_compression_toolkit.core.common.fusion.layer_fusing import fusion +from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator from model_compression_toolkit.core.common.graph.base_graph import Graph from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates @@ -136,6 +136,7 @@ def get_finalized_graph(initial_graph: Graph, node.prior_info = fw_impl.get_node_prior_info(node=node, fw_info=fw_info, graph=graph) + ################################################## # Graph substitution (pre statistics 
collection) ################################################## @@ -161,7 +162,9 @@ def get_finalized_graph(initial_graph: Graph, ###################################### # Layer fusing ###################################### - transformed_graph = fusion(transformed_graph, fqc) + fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(transformed_graph) + transformed_graph.fusing_info = fusing_info + transformed_graph.disable_fused_nodes_activation_quantization() ###################################### # Channel equalization diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index 3cfe0810d..f6e8e5dde 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -184,15 +184,14 @@ def core_runner(in_model: Any, scheduler_info = None if core_config.debug_config.simulate_scheduler: - graph_to_fuse = copy.deepcopy(tg) - fused_nodes_mapping = GraphFuser().create_fused_graph(graph_to_fuse) - memory_graph = MemoryGraph(graph_to_fuse) + fused_graph = GraphFuser().apply_node_fusion(tg) + memory_graph = MemoryGraph(fused_graph) schedule, max_cut, cuts = compute_graph_max_cut(memory_graph) scheduler_info = SchedulerInfo( operators_scheduling=schedule, max_cut=float(max_cut), cuts=cuts, - fused_nodes_mapping=fused_nodes_mapping + fused_nodes_mapping=tg.fusing_info.get_node_to_fused_node_map() ) return tg, bit_widths_config, hessian_info_service, scheduler_info diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/compute_max_cut_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/compute_max_cut_test.py index 64d7c81e9..f37be8d26 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/compute_max_cut_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/compute_max_cut_test.py @@ -18,6 +18,7 @@ from mct_quantizers.keras.metadata import get_metadata from model_compression_toolkit.constants import 
TENSORFLOW +from model_compression_toolkit.core.common.fusion.fusing_info import FUSED_OP_ID_PREFIX from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL from tests.common_tests.helpers.tpcs_for_tests.v2.tpc import get_tpc from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest @@ -51,16 +52,16 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= _metadata = get_metadata(quantized_model) self.unit_test.assertEqual(_metadata['scheduling_info']['operators_scheduling'], ['InputLayer:input_layer', - 'FusedLayerType:FusedNode_conv2d_1_bn_relu_1', - 'FusedLayerType:FusedNode_conv2d_2_bn_relu_2', + f'FusedLayerType:{FUSED_OP_ID_PREFIX}conv2d_1_bn_relu_1', + f'FusedLayerType:{FUSED_OP_ID_PREFIX}conv2d_2_bn_relu_2', 'Add:add_layer']) self.unit_test.assertEqual(_metadata['scheduling_info']['max_cut'], 256 * 3) expected_fused_nodes_mapping = { - 'conv2d_1_bn': 'FusedNode_conv2d_1_bn_relu_1', - 'relu_1': 'FusedNode_conv2d_1_bn_relu_1', - 'conv2d_2_bn': 'FusedNode_conv2d_2_bn_relu_2', - 'relu_2': 'FusedNode_conv2d_2_bn_relu_2' + 'conv2d_1_bn': f'{FUSED_OP_ID_PREFIX}conv2d_1_bn_relu_1', + 'relu_1': f'{FUSED_OP_ID_PREFIX}conv2d_1_bn_relu_1', + 'conv2d_2_bn': f'{FUSED_OP_ID_PREFIX}conv2d_2_bn_relu_2', + 'relu_2': f'{FUSED_OP_ID_PREFIX}conv2d_2_bn_relu_2' } self.unit_test.assertEqual(_metadata['scheduling_info']['fused_nodes_mapping'], expected_fused_nodes_mapping) diff --git a/tests/keras_tests/function_tests/test_activation_weights_composition_substitution.py b/tests/keras_tests/function_tests/test_activation_weights_composition_substitution.py index be3cd7a21..681e456df 100644 --- a/tests/keras_tests/function_tests/test_activation_weights_composition_substitution.py +++ b/tests/keras_tests/function_tests/test_activation_weights_composition_substitution.py @@ -20,6 +20,7 @@ from packaging import version import tensorflow as tf +from 
model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator from model_compression_toolkit.core.common.quantization.quantization_config import CustomOpsetLayers from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \ AttachTpcToKeras @@ -31,7 +32,6 @@ from keras.layers import Conv2D, Conv2DTranspose, DepthwiseConv2D, Dense, BatchNormalization, ReLU, Input, Add, InputLayer import numpy as np -from model_compression_toolkit.core.common.fusion.layer_fusing import fusion from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualSplitActivationNode, \ VirtualActivationWeightsNode, VirtualSplitWeightsNode from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates @@ -123,7 +123,11 @@ def prepare_graph(in_model, keras_impl, mixed_precision_candidates_list, base_co graph = set_quantization_configuration_to_graph(graph=graph, quant_config=qc, mixed_precision_enable=True) - graph = fusion(graph, fqc) + + fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(graph) + graph.fusing_info = fusing_info + graph.disable_fused_nodes_activation_quantization() + graph = filter_nodes_candidates(graph) return graph @@ -168,10 +172,17 @@ def test_two_conv_net_compose_after_split(self): mixed_precision_candidates_list=_get_base_mp_nbits_candidates(), base_config=base_config, default_config=default_config) + # Validation is skipped because fusing information is not relevant for the virtual graph. + # Therefore, validation checks are disabled before the virtual graph substitution and + # re-enabled once it completes. 
+ graph.skip_validation_check = True + # Nodes split and composition substitution split_graph = substitute(copy.deepcopy(graph), [WeightsActivationSplit()]) v_graph = substitute(copy.deepcopy(split_graph), [VirtualActivationWeightsComposition()]) + graph.skip_validation_check = False + self._verify_two_conv_with_split_test(graph, v_graph, 9, 9) def test_two_conv_net_compose_after_split_weights_only(self): @@ -184,10 +195,17 @@ def test_two_conv_net_compose_after_split_weights_only(self): mixed_precision_candidates_list=_get_base_mp_nbits_candidates(), base_config=base_config, default_config=default_config) + # Validation is skipped because fusing information is not relevant for the virtual graph. + # Therefore, validation checks are disabled before the virtual graph substitution and + # re-enabled once it completes. + graph.skip_validation_check = True + # Nodes split and composition substitution split_graph = substitute(copy.deepcopy(graph), [WeightsActivationSplit()]) v_graph = substitute(copy.deepcopy(split_graph), [VirtualActivationWeightsComposition()]) + graph.skip_validation_check = False + self._verify_two_conv_with_split_test(graph, v_graph, 3, 3) def test_two_conv_net_compose_after_split_activation_only(self): @@ -201,9 +219,17 @@ def test_two_conv_net_compose_after_split_activation_only(self): mixed_precision_candidates_list=_get_base_mp_nbits_candidates(), base_config=base_config, default_config=default_config) + # Validation is skipped because fusing information is not relevant for the virtual graph. + # Therefore, validation checks are disabled before the virtual graph substitution and + # re-enabled once it completes. 
+ graph.skip_validation_check = True + # Nodes split and composition substitution split_graph = substitute(copy.deepcopy(graph), [WeightsActivationSplit()]) v_graph = substitute(copy.deepcopy(split_graph), [VirtualActivationWeightsComposition()]) + + graph.skip_validation_check = False + self._verify_two_conv_with_split_test(graph, v_graph, 3, 3) def test_all_weights_layers_composition(self): @@ -216,10 +242,17 @@ def test_all_weights_layers_composition(self): base_config=base_config, default_config=default_config) + # Validation is skipped because fusing information is not relevant for the virtual graph. + # Therefore, validation checks are disabled before the virtual graph substitution and + # re-enabled once it completes. + graph.skip_validation_check = True + # Nodes split and composition substitution split_graph = substitute(copy.deepcopy(graph), [WeightsActivationSplit()]) v_graph = substitute(copy.deepcopy(split_graph), [VirtualActivationWeightsComposition()]) + graph.skip_validation_check = False + assert split_graph is not graph self.assertTrue(len(v_graph.nodes) == 8) self.assertTrue(len([n for n in v_graph.nodes if isinstance(n, VirtualActivationWeightsNode)]) == 5) @@ -254,9 +287,16 @@ def test_multiple_output_activation(self): mixed_precision_candidates_list=_get_base_mp_nbits_candidates(), base_config=base_config, default_config=default_config) + # Validation is skipped because fusing information is not relevant for the virtual graph. + # Therefore, validation checks are disabled before the virtual graph substitution and + # re-enabled once it completes. 
+ graph.skip_validation_check = True + split_graph = substitute(copy.deepcopy(graph), [WeightsActivationSplit()]) v_graph = substitute(copy.deepcopy(split_graph), [VirtualActivationWeightsComposition()]) + graph.skip_validation_check = False + # Since the only activation before the convolutions is the Input layer activation, and it goes to both # convolutions (the input node has multiple output edges) no composition should be made. self.assertTrue(len(v_graph.nodes) == len(split_graph.nodes)) diff --git a/tests/keras_tests/function_tests/test_cfg_candidates_filter.py b/tests/keras_tests/function_tests/test_cfg_candidates_filter.py index 272b0683e..f54f75382 100644 --- a/tests/keras_tests/function_tests/test_cfg_candidates_filter.py +++ b/tests/keras_tests/function_tests/test_cfg_candidates_filter.py @@ -20,13 +20,13 @@ import model_compression_toolkit as mct from model_compression_toolkit.constants import FLOAT_BITWIDTH from model_compression_toolkit.core import CustomOpsetLayers +from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \ set_quantization_configuration_to_graph from model_compression_toolkit.core.keras.constants import KERNEL from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation -from model_compression_toolkit.core.common.fusion.layer_fusing import fusion from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \ AttachTpcToKeras from tests.common_tests.helpers.generate_test_tpc import generate_test_attr_configs, generate_test_op_qc @@ -58,7 +58,10 @@ def prepare_graph(in_model, base_config, default_config, bitwidth_candidates): graph = 
set_quantization_configuration_to_graph(graph=graph, quant_config=mct.core.QuantizationConfig(), mixed_precision_enable=True) - graph = fusion(graph, fqc) + + fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(graph) + graph.fusing_info = fusing_info + graph.disable_fused_nodes_activation_quantization() return graph diff --git a/tests/keras_tests/function_tests/test_layer_fusing.py b/tests/keras_tests/function_tests/test_layer_fusing.py deleted file mode 100644 index 9c14045e5..000000000 --- a/tests/keras_tests/function_tests/test_layer_fusing.py +++ /dev/null @@ -1,255 +0,0 @@ -import unittest -import numpy as np -import tensorflow as tf - -import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema -from model_compression_toolkit.core import DEFAULTCONFIG, QuantizationConfig -from model_compression_toolkit.core.common.fusion.layer_fusing import fusion -from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO -from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation -from model_compression_toolkit.core.common.quantization.quantization_config import CustomOpsetLayers -from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams -from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \ - AttachTpcToKeras -from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \ - get_op_quantization_configs -import model_compression_toolkit as mct -from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_configs - -if tf.__version__ < "2.6": - from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Dense, Activation, ReLU, Add -else: - from keras.layers import Conv2D, DepthwiseConv2D, Dense, Activation, ReLU, Add - -keras = tf.keras -layers = keras.layers -activations = 
keras.activations - -INPUT_SHAPE = (16, 16, 3) - - -def representative_dataset(): - yield [np.random.randn(1, 16, 16, 3).astype(np.float32)] - - -def create_network_1(input_shape): - inputs = layers.Input(shape=input_shape) - x = layers.Conv2D(filters=16, kernel_size=(3, 3))(inputs) - y = layers.Conv2D(filters=16, kernel_size=(1, 1), activation='relu')(x) - return tf.keras.models.Model(inputs=inputs, outputs=y) - - -def create_network_2(input_shape): - inputs = layers.Input(shape=input_shape) - x = layers.Conv2D(filters=16, kernel_size=(3, 3))(inputs) - x = layers.Conv2D(filters=16, kernel_size=(1, 1), activation='tanh')(x) - x = layers.Conv2D(filters=16, kernel_size=(3, 3))(x) - x = layers.ReLU()(x) - x = layers.Conv2D(filters=16, kernel_size=(1, 1))(x) - x = activations.sigmoid(x) - y = layers.Conv2D(filters=16, kernel_size=(1, 1), activation='swish')(x) - return tf.keras.models.Model(inputs=inputs, outputs=y) - - -def create_network_3(input_shape): - inputs = layers.Input(shape=input_shape) - x = layers.Conv2D(filters=16, kernel_size=(3, 3))(inputs) - x = layers.Conv2D(filters=16, kernel_size=(1, 1), activation='tanh')(x) - x = layers.Conv2D(filters=16, kernel_size=(3, 3), activation='relu')(x) - x = layers.Conv2D(filters=16, kernel_size=(1, 1))(x) - x = activations.sigmoid(x) - y = layers.Conv2D(filters=16, kernel_size=(1, 1), activation='swish')(x) - return tf.keras.models.Model(inputs=inputs, outputs=y) - - -def create_network_4(input_shape): - inputs = layers.Input(shape=input_shape) - x = layers.Conv2D(filters=3, kernel_size=(1, 1), padding='same', activation='swish')(inputs) - x1 = layers.Add()([x, inputs]) - x2 = layers.Conv2D(filters=3, kernel_size=(2, 2), padding='same', activation='swish')(x1) - x2 = layers.Add()([x1, x2]) - x2 = layers.Conv2D(filters=3, kernel_size=(1, 1), padding='same', activation='relu')(x2) - x3 = layers.Conv2D(filters=3, kernel_size=(2, 2), padding='same')(x2) - x3 = layers.ReLU()(x3) - x3 = layers.Add()([x2, x3]) - x3 = 
layers.Flatten()(x3) - x3 = layers.Dense(units=16)(x3) - x3 = activations.swish(x3) - y = layers.Dense(units=16, activation='swish')(x3) - return tf.keras.models.Model(inputs=inputs, outputs=y) - - -def generate_base_tpc(operator_set, fusing_patterns): - base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() - default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple( - [default_config])) - generated_tp = schema.TargetPlatformCapabilities( - default_qco=default_configuration_options, - tpc_minor_version=None, - tpc_patch_version=None, - tpc_platform_type=None, - operator_set=tuple(operator_set), - fusing_patterns=tuple(fusing_patterns), - add_metadata=False, name='layer_fusing_test') - - return generated_tp - - -def get_tpc_1(): - base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() - mixed_precision_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple(mixed_precision_cfg_list), - base_config=base_config) - conv = schema.OperatorsSet(name="Conv", qc_options=mixed_precision_configuration_options) - any_relu = schema.OperatorsSet(name="AnyReLU") - operator_set = [conv, any_relu] - # Define fusions - fusing_patterns = [schema.Fusing(operator_groups=(conv, any_relu))] - - generated_tp = generate_base_tpc(operator_set, fusing_patterns) - - return generated_tp - -def get_tpc_2(): - base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() - mixed_precision_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple(mixed_precision_cfg_list), - base_config=base_config) - conv = schema.OperatorsSet(name="Conv", qc_options=mixed_precision_configuration_options) - any_relu = schema.OperatorsSet(name="AnyReLU") - swish = schema.OperatorsSet(name="Swish") - sigmoid = schema.OperatorsSet(name="Sigmoid") - tanh = schema.OperatorsSet(name="Tanh") - operator_set = [conv, any_relu, 
swish, sigmoid, tanh] - activations_after_conv_to_fuse = schema.OperatorSetGroup(operators_set=[any_relu, swish, sigmoid, tanh]) - # Define fusions - fusing_patterns = [schema.Fusing(operator_groups=(conv, activations_after_conv_to_fuse))] - - generated_tp = generate_base_tpc(operator_set, fusing_patterns) - - return generated_tp - -def get_tpc_3(): - base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() - mixed_precision_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple(mixed_precision_cfg_list), - base_config=base_config) - conv = schema.OperatorsSet(name="Conv", qc_options=mixed_precision_configuration_options) - any_relu = schema.OperatorsSet(name="AnyReLU") - operator_set = [conv, any_relu] - # Define fusions - fusing_patterns = [schema.Fusing(operator_groups=(conv, any_relu))] - - generated_tp = generate_base_tpc(operator_set, fusing_patterns) - - return generated_tp - - -def get_tpc_4(): - base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() - mixed_precision_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple(mixed_precision_cfg_list), - base_config=base_config) - conv = schema.OperatorsSet(name="Conv", qc_options=mixed_precision_configuration_options) - fc = schema.OperatorsSet(name="FullyConnected", qc_options=mixed_precision_configuration_options) - any_relu = schema.OperatorsSet(name="AnyReLU") - add = schema.OperatorsSet(name="Add") - swish = schema.OperatorsSet(name="Swish") - activations_to_fuse = schema.OperatorSetGroup(operators_set=[any_relu, swish]) - operator_set = [conv, fc, any_relu, add, swish] - # Define fusions - fusing_patterns = [schema.Fusing(operator_groups=(conv, activations_to_fuse)), - schema.Fusing(operator_groups=(conv, add, activations_to_fuse)), - schema.Fusing(operator_groups=(conv, activations_to_fuse, add)), - schema.Fusing(operator_groups=(fc, activations_to_fuse))] - - generated_tp 
= generate_base_tpc(operator_set, fusing_patterns) - - return generated_tp - - -def get_type(fusion): - fusion_types = [x.type for x in fusion] - return fusion_types - - -class TestLayerFusing(unittest.TestCase): - def _compare(self, fused_nodes, expected_fusions): - self.assertTrue(len(fused_nodes) == len(expected_fusions), - msg=f'Number of fusions is not as expected!') - type_names = lambda types_list: [t.__name__ for t in types_list] - for i, fusion in enumerate(fused_nodes): - self.assertTrue(get_type(fusion) == expected_fusions[i] or - type_names(get_type(fusion)) == type_names(expected_fusions[i]), - msg=f'Miss-match fusion compared to expected!') - - def test_layer_fusing_1(self): - expected_fusions = [[Conv2D, Activation]] - model = create_network_1(INPUT_SHAPE) - - qc = QuantizationConfig(custom_tpc_opset_to_layer={"Conv": CustomOpsetLayers([Conv2D]), - "AnyReLU": CustomOpsetLayers([tf.nn.relu, - LayerFilterParams(ReLU, negative_slope=0.0), - LayerFilterParams(Activation, activation="relu")])}) - - fusion_graph = prepare_graph_with_configs(model, KerasImplementation(), DEFAULT_KERAS_INFO, - representative_dataset, lambda name, _tp: get_tpc_1(), - attach2fw=AttachTpcToKeras(), qc=qc) - - self._compare(fusion_graph.fused_nodes, expected_fusions) - - def test_layer_fusing_2(self): - expected_fusions = [[Conv2D, Activation], [Conv2D, ReLU], [Conv2D, tf.nn.sigmoid], [Conv2D, Activation]] - model = create_network_2(INPUT_SHAPE) - - qc = QuantizationConfig(custom_tpc_opset_to_layer={"Conv": CustomOpsetLayers([Conv2D]), - "AnyReLU": CustomOpsetLayers([tf.nn.relu, - LayerFilterParams(ReLU, negative_slope=0.0), - LayerFilterParams(Activation, - activation="relu")]), - "Swish": CustomOpsetLayers([tf.nn.swish, LayerFilterParams(Activation, - activation="swish")]), - "Sigmoid": CustomOpsetLayers([tf.nn.sigmoid, LayerFilterParams(Activation, - activation="sigmoid")]), - "Tanh": CustomOpsetLayers([tf.nn.tanh, LayerFilterParams(Activation, - activation="tanh")])}) - - 
fusion_graph = prepare_graph_with_configs(model, KerasImplementation(), DEFAULT_KERAS_INFO, - representative_dataset, lambda name, _tp: get_tpc_2(), - attach2fw=AttachTpcToKeras(), qc=qc) - - self._compare(fusion_graph.fused_nodes, expected_fusions) - - def test_layer_fusing_3(self): - expected_fusions = [[Conv2D, Activation]] - model = create_network_3(INPUT_SHAPE) - - qc = QuantizationConfig(custom_tpc_opset_to_layer={"Conv": CustomOpsetLayers([Conv2D]), - "AnyReLU": CustomOpsetLayers([tf.nn.relu, - LayerFilterParams(ReLU, negative_slope=0.0), - LayerFilterParams(Activation, - activation="relu")])}) - - fusion_graph = prepare_graph_with_configs(model, KerasImplementation(), DEFAULT_KERAS_INFO, - representative_dataset, lambda name, _tp: get_tpc_3(), - attach2fw=AttachTpcToKeras(), qc=qc) - - self._compare(fusion_graph.fused_nodes, expected_fusions) - - def test_layer_fusing_4(self): - expected_fusions = [[Conv2D, Activation, Add], [Conv2D, Activation, Add], [Conv2D, Activation], - [Conv2D, ReLU, Add], [Dense, tf.nn.silu], [Dense, Activation]] - model = create_network_4(INPUT_SHAPE) - - qc = QuantizationConfig(custom_tpc_opset_to_layer={ - "Conv": CustomOpsetLayers([Conv2D]), - "FullyConnected": CustomOpsetLayers([Dense]), - "AnyReLU": CustomOpsetLayers([tf.nn.relu, - LayerFilterParams(ReLU, negative_slope=0.0), - LayerFilterParams(Activation, - activation="relu")]), - "Add": CustomOpsetLayers([tf.add, Add]), - "Swish": CustomOpsetLayers([tf.nn.swish, LayerFilterParams(Activation, activation="swish")]), - }) - - fusion_graph = prepare_graph_with_configs(model, KerasImplementation(), DEFAULT_KERAS_INFO, - representative_dataset, lambda name, _tp: get_tpc_4(), - attach2fw=AttachTpcToKeras(), qc=qc) - - self._compare(fusion_graph.fused_nodes, expected_fusions) diff --git a/tests/keras_tests/function_tests/test_weights_activation_split_substitution.py b/tests/keras_tests/function_tests/test_weights_activation_split_substitution.py index 998e7bf00..8d29ebf51 100644 --- 
a/tests/keras_tests/function_tests/test_weights_activation_split_substitution.py +++ b/tests/keras_tests/function_tests/test_weights_activation_split_substitution.py @@ -91,9 +91,16 @@ def setup_test(in_model, keras_impl, mixed_precision_candidates_list): attach2fw=AttachTpcToKeras(), qc=QuantizationConfig(custom_tpc_opset_to_layer={"Input": CustomOpsetLayers([InputLayer])})) + # Validation is skipped because fusing information is not relevant for the virtual graph. + # Therefore, validation checks are disabled before the virtual graph substitution and + # re-enabled once it completes. + graph.skip_validation_check = True + # Split graph substitution split_graph = substitute(copy.deepcopy(graph), [WeightsActivationSplit()]) + graph.skip_validation_check = False + return graph, split_graph diff --git a/tests/pytorch_tests/function_tests/layer_fusing_test.py b/tests/pytorch_tests/function_tests/layer_fusing_test.py deleted file mode 100644 index 399b427f8..000000000 --- a/tests/pytorch_tests/function_tests/layer_fusing_test.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -import torch -import torch.nn as nn -from torch.nn import Conv2d, ReLU, SiLU, Sigmoid, Linear, Hardtanh -from torch.nn.functional import relu, relu6 - -import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema -from model_compression_toolkit.core import QuantizationConfig -from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO -from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation -from model_compression_toolkit.core.common.quantization.quantization_config import CustomOpsetLayers -from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams -from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \ - AttachTpcToPytorch -from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \ - get_op_quantization_configs -from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_configs -from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest - -import model_compression_toolkit as mct - - - -class BaseLayerFusingTest(BasePytorchTest): - - def __init__(self, unit_test): - super().__init__(unit_test=unit_test) - self.expected_fusions = [] - - def create_inputs_shape(self): - return [[self.val_batch_size, 3, 16, 16]] - - def representative_data_gen(self): - input_shapes = self.create_inputs_shape() - yield self.generate_inputs(input_shapes) - - def get_type(self, fusion): - fusion_types = [x.type for x in fusion] - return fusion_types - - def _compare(self, fused_nodes): - self.unit_test.assertTrue(len(fused_nodes) == len(self.expected_fusions), - msg=f'Number of fusions is not as expected!') - for i, fusion in enumerate(fused_nodes): - self.unit_test.assertTrue(self.get_type(fusion) == 
self.expected_fusions[i], - msg=f'Miss-match fusion compared to expected!') - - -class LayerFusingTest1(BaseLayerFusingTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.expected_fusions = [[nn.Conv2d, nn.ReLU]] - self.attach2fw = AttachTpcToPytorch() - - def get_tpc(self): - base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() - default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple([default_config])) - mixed_precision_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple(mixed_precision_cfg_list), - base_config=base_config) - conv = schema.OperatorsSet(name="Conv", qc_options=mixed_precision_configuration_options) - any_relu = schema.OperatorsSet(name="ReLU") - operator_set = [conv, any_relu] - # Define fusions - fusing_patterns = [schema.Fusing(operator_groups=(conv, any_relu))] - generated_tp = schema.TargetPlatformCapabilities(default_qco=default_configuration_options, - tpc_minor_version=None, - tpc_patch_version=None, - tpc_platform_type=None, - operator_set=tuple(operator_set), - fusing_patterns=tuple(fusing_patterns), - name='layer_fusing_test') - - return generated_tp - - def run_test(self, seed=0): - model_float = self.LayerFusingNetTest() - - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, lambda name, _tp: self.get_tpc(), - attach2fw=self.attach2fw) - - self._compare(graph.fused_nodes) - - class LayerFusingNetTest(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 16, kernel_size=(3, 3)) - self.conv2 = nn.Conv2d(16, 32, kernel_size=(1, 1)) - self.relu = nn.ReLU() - - def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) - y = self.relu(x) - return y - - -class LayerFusingTest2(BaseLayerFusingTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.expected_fusions = [[Conv2d, 
Hardtanh], [Conv2d, ReLU], [Conv2d, Sigmoid], [Conv2d, SiLU]] - - def get_tpc(self): - base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() - mixed_precision_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple(mixed_precision_cfg_list), - base_config=base_config) - default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple([default_config])) - conv = schema.OperatorsSet(name="Conv", qc_options=mixed_precision_configuration_options) - any_act = schema.OperatorsSet(name="AnyAct") - operator_set = [conv, any_act] - # Define fusions - fusing_patterns = [schema.Fusing(operator_groups=(conv, any_act))] - generated_tp = schema.TargetPlatformCapabilities(default_qco=default_configuration_options, - tpc_minor_version=None, - tpc_patch_version=None, - tpc_platform_type=None, - operator_set=tuple(operator_set), - fusing_patterns=tuple(fusing_patterns), - name='layer_fusing_test') - - return generated_tp - - def run_test(self, seed=0): - model_float = self.LayerFusingNetTest() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, lambda name, _tp: self.get_tpc(), - attach2fw=AttachTpcToPytorch(), - qc=QuantizationConfig( - custom_tpc_opset_to_layer={"AnyAct": CustomOpsetLayers([ReLU, relu6, relu, SiLU, Sigmoid, - LayerFilterParams(Hardtanh, min_val=0)])})) - - self._compare(graph.fused_nodes) - - class LayerFusingNetTest(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 32, kernel_size=(3, 3)) - self.conv2 = nn.Conv2d(32, 32, kernel_size=(1, 1)) - self.conv3 = nn.Conv2d(32, 32, kernel_size=(3, 3)) - self.conv4 = nn.Conv2d(32, 64, kernel_size=(1, 1)) - self.conv5 = nn.Conv2d(64, 64, kernel_size=(2, 2)) - self.relu = nn.ReLU() - self.tanh = Hardtanh(min_val=0) - self.swish = nn.SiLU() - self.sigmoid = nn.Sigmoid() - - def forward(self, x): - x = self.conv1(x) - x 
= self.conv2(x) - x = self.tanh(x) - x = self.conv3(x) - x = self.relu(x) - x = self.conv4(x) - x = self.sigmoid(x) - x = self.conv5(x) - y = self.swish(x) - return y - - -class LayerFusingTest3(BaseLayerFusingTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.expected_fusions = [[Conv2d, ReLU]] - - def get_tpc(self): - base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() - mixed_precision_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple(mixed_precision_cfg_list), - base_config=base_config) - default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple([default_config])) - conv = schema.OperatorsSet(name="Conv", qc_options=mixed_precision_configuration_options) - any_act = schema.OperatorsSet(name="AnyAct") - operator_set = [conv, any_act] - # Define fusions - fusing_patterns = [schema.Fusing(operator_groups=(conv, any_act))] - generated_tp = schema.TargetPlatformCapabilities(default_qco=default_configuration_options, - tpc_minor_version=None, - tpc_patch_version=None, - tpc_platform_type=None, - operator_set=tuple(operator_set), - fusing_patterns=tuple(fusing_patterns), - name='layer_fusing_test') - return generated_tp - - def run_test(self, seed=0): - model_float = self.LayerFusingNetTest() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, lambda name, _tp: self.get_tpc(), - attach2fw=AttachTpcToPytorch(), - qc=QuantizationConfig( - custom_tpc_opset_to_layer={"AnyAct": CustomOpsetLayers([ReLU, relu6, relu])})) - - self._compare(graph.fused_nodes) - - class LayerFusingNetTest(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 32, kernel_size=(3, 3)) - self.conv2 = nn.Conv2d(32, 32, kernel_size=(1, 1)) - self.conv3 = nn.Conv2d(32, 32, kernel_size=(3, 3)) - self.conv4 = nn.Conv2d(32, 64, kernel_size=(1, 1)) - self.conv5 
= nn.Conv2d(64, 64, kernel_size=(2, 2)) - self.relu = nn.ReLU() - self.tanh = nn.Tanh() - self.swish = nn.SiLU() - self.sigmoid = nn.Sigmoid() - - def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) - x = self.tanh(x) - x = self.conv3(x) - x = self.relu(x) - x = self.conv4(x) - x = self.sigmoid(x) - x = self.conv5(x) - y = self.swish(x) - return y - - -class LayerFusingTest4(BaseLayerFusingTest): - def __init__(self, unit_test): - super().__init__(unit_test) - self.expected_fusions = [[Conv2d, SiLU, torch.add], [Conv2d, SiLU, torch.add], [Conv2d, ReLU], - [Conv2d, ReLU, torch.add], [Linear, SiLU], [Linear, SiLU]] - - def get_tpc(self): - base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() - mixed_precision_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple(mixed_precision_cfg_list), - base_config=base_config) - default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple([default_config])) - conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV, qc_options=mixed_precision_configuration_options) - fc = schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED, qc_options=mixed_precision_configuration_options) - relu = schema.OperatorsSet(name=schema.OperatorSetNames.RELU) - add = schema.OperatorsSet(name=schema.OperatorSetNames.ADD) - swish = schema.OperatorsSet(name=schema.OperatorSetNames.SWISH) - operator_set = [conv, fc, relu, add, swish] - activations_to_fuse = schema.OperatorSetGroup(operators_set=[relu, swish]) - # Define fusions - fusing_patterns = [schema.Fusing(operator_groups=(conv, activations_to_fuse)), - schema.Fusing(operator_groups=(conv, add, activations_to_fuse)), - schema.Fusing(operator_groups=(conv, activations_to_fuse, add)), - schema.Fusing(operator_groups=(fc, activations_to_fuse))] - - generated_tp = schema.TargetPlatformCapabilities(default_qco=default_configuration_options, - tpc_minor_version=None, - 
tpc_patch_version=None, - tpc_platform_type=None, - operator_set=tuple(operator_set), - fusing_patterns=tuple(fusing_patterns), - name='layer_fusing_test') - - return generated_tp - - def run_test(self, seed=0): - model_float = self.LayerFusingNetTest() - graph = prepare_graph_with_configs(model_float, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - self.representative_data_gen, lambda name, _tp: self.get_tpc(), - attach2fw=AttachTpcToPytorch()) - - self._compare(graph.fused_nodes) - - class LayerFusingNetTest(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding='same') - self.conv2 = nn.Conv2d(3, 3, kernel_size=(1, 1), padding='same') - self.conv3 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding='same') - self.conv4 = nn.Conv2d(3, 3, kernel_size=(1, 1), padding='same') - self.conv5 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding='same') - self.conv6 = nn.Conv2d(3, 3, kernel_size=(1, 1), padding='same') - self.relu = nn.ReLU() - self.swish = nn.SiLU() - self.flatten = nn.Flatten() - self.dense1 = nn.Linear(768, out_features=16) - self.dense2 = nn.Linear(16, out_features=16) - - def forward(self, inputs): - x = self.conv1(inputs) - x = self.swish(x) - x1 = torch.add(inputs, x) - x2 = self.conv2(x1) - x2 = self.swish(x2) - x2 = torch.add(x1, x2) - x2 = self.conv3(x2) - x2 = self.relu(x2) - x3 = self.conv4(x2) - x3 = self.relu(x3) - x3 = torch.add(x3, x2) - x3 = self.flatten(x3) - x3 = self.dense1(x3) - x3 = self.swish(x3) - x3 = self.dense2(x3) - y = self.swish(x3) - return y diff --git a/tests/pytorch_tests/function_tests/test_function_runner.py b/tests/pytorch_tests/function_tests/test_function_runner.py index c208ccdf2..47fb68847 100644 --- a/tests/pytorch_tests/function_tests/test_function_runner.py +++ b/tests/pytorch_tests/function_tests/test_function_runner.py @@ -20,8 +20,6 @@ Conv2D2BNInfoCollectionTest, Conv2DBNChainInfoCollectionTest, BNChainInfoCollectionTest, \ BNLayerInfoCollectionTest, 
INP2BNInfoCollectionTest from tests.pytorch_tests.function_tests.get_gptq_config_test import TestGetGPTQConfig -from tests.pytorch_tests.function_tests.layer_fusing_test import LayerFusingTest1, LayerFusingTest2, LayerFusingTest3, \ - LayerFusingTest4 from tests.pytorch_tests.function_tests.set_device_test import SetDeviceTest from tests.pytorch_tests.function_tests.set_layer_to_bitwidth_test import TestSetLayerToBitwidthWeights, \ TestSetLayerToBitwidthActivation @@ -116,15 +114,6 @@ def test_hessian_service(self): FetchHessianMultipleNodesTest(self).run_test() DoubleFetchHessianTest(self).run_test() - def test_layer_fusing(self): - """ - This test checks the Fusion mechanism in Pytorch. - """ - LayerFusingTest1(self).run_test() - LayerFusingTest2(self).run_test() - LayerFusingTest3(self).run_test() - LayerFusingTest4(self).run_test() - def test_mixed_precision_set_bitwidth(self): """ This test checks the functionality of setting a configurable layer's weights bit-width for mixed precision diff --git a/tests/pytorch_tests/model_tests/feature_models/compute_max_cut_test.py b/tests/pytorch_tests/model_tests/feature_models/compute_max_cut_test.py index e72c55187..bcfe430bd 100644 --- a/tests/pytorch_tests/model_tests/feature_models/compute_max_cut_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/compute_max_cut_test.py @@ -15,6 +15,7 @@ import torch.nn as nn import model_compression_toolkit as mct +from model_compression_toolkit.core.common.fusion.fusing_info import FUSED_OP_ID_PREFIX from tests.common_tests.helpers.tpcs_for_tests.v2.tpc import get_tpc from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL @@ -54,15 +55,15 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= _metadata = get_metadata(quantized_model) 
self.unit_test.assertEqual(_metadata['scheduling_info']['operators_scheduling'], ['DummyPlaceHolder:x', - 'FusedLayerType:FusedNode_conv2d_1_bn_relu_1', - 'FusedLayerType:FusedNode_conv2d_2_bn_relu_2', + f'FusedLayerType:{FUSED_OP_ID_PREFIX}conv2d_1_bn_relu_1', + f'FusedLayerType:{FUSED_OP_ID_PREFIX}conv2d_2_bn_relu_2', 'add:add']) self.unit_test.assertEqual(_metadata['scheduling_info']['max_cut'], 256 * 3) expected_fused_nodes_mapping = { - 'conv2d_1_bn': 'FusedNode_conv2d_1_bn_relu_1', - 'relu_1': 'FusedNode_conv2d_1_bn_relu_1', - 'conv2d_2_bn': 'FusedNode_conv2d_2_bn_relu_2', - 'relu_2': 'FusedNode_conv2d_2_bn_relu_2' + 'conv2d_1_bn': f"{FUSED_OP_ID_PREFIX}conv2d_1_bn_relu_1", + 'relu_1': f"{FUSED_OP_ID_PREFIX}conv2d_1_bn_relu_1", + 'conv2d_2_bn': f"{FUSED_OP_ID_PREFIX}conv2d_2_bn_relu_2", + 'relu_2': f"{FUSED_OP_ID_PREFIX}conv2d_2_bn_relu_2" } self.unit_test.assertEqual(_metadata['scheduling_info']['fused_nodes_mapping'], expected_fused_nodes_mapping) diff --git a/tests_pytest/_fw_tests_common_base/base_ru_integration_test.py b/tests_pytest/_fw_tests_common_base/base_ru_integration_test.py index b51e86d90..3d5c3005e 100644 --- a/tests_pytest/_fw_tests_common_base/base_ru_integration_test.py +++ b/tests_pytest/_fw_tests_common_base/base_ru_integration_test.py @@ -212,8 +212,14 @@ def test_mult_output_activation(self): assert self._extract_values(detailed_orig[RUTarget.WEIGHTS]) == exp_w_ru assert self._extract_values(detailed_orig[RUTarget.BOPS]) == exp_bops + # Validation is skipped because fusing information is not relevant for the virtual graph. + # Therefore, validation checks are disabled before the virtual graph substitution and + # re-enabled once it completes. 
+ graph.skip_validation_check = True virtual_graph = substitute(copy.deepcopy(graph), self.fw_impl.get_substitutions_virtual_weights_activation_coupling()) + graph.skip_validation_check = False + assert len(virtual_graph.nodes) == 8 assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualActivationWeightsNode)]) == 1 assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitActivationNode)]) == 3 diff --git a/tests_pytest/_fw_tests_common_base/fusing/__init__.py b/tests_pytest/_fw_tests_common_base/fusing/__init__.py new file mode 100644 index 000000000..5397dea24 --- /dev/null +++ b/tests_pytest/_fw_tests_common_base/fusing/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/tests_pytest/_fw_tests_common_base/fusing/base_fusing_info_generator_test.py b/tests_pytest/_fw_tests_common_base/fusing/base_fusing_info_generator_test.py new file mode 100644 index 000000000..e492277bf --- /dev/null +++ b/tests_pytest/_fw_tests_common_base/fusing/base_fusing_info_generator_test.py @@ -0,0 +1,173 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import random +from typing import Callable, List + +import abc + +import pytest +from mct_quantizers import QuantizationMethod + +from model_compression_toolkit.core import FrameworkInfo +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo +from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser +from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \ + CandidateNodeQuantizationConfig +from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner +import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema +from tests_pytest._test_util.tpc_util import minimal_cfg_options + + +class MockNodeActivationQuantizationConfig: + def __init__(self, n_bits: int): + self.activation_n_bits = n_bits + +def random_activation_configs(): + num_candidates = random.choice([1, 2, 3]) + bits_list = random.sample(range(2, 9), k=num_candidates) + qcs = [ + CandidateNodeQuantizationConfig( + weights_quantization_cfg=None, + activation_quantization_cfg=MockNodeActivationQuantizationConfig(n_bits=nb) + ) + for nb in bits_list + ] + return bits_list, qcs + +def get_activation_mp_options(last_node_activation_nbits): + options = tuple([schema.OpQuantizationConfig( + default_weight_attr_config={}, + attr_weights_configs_mapping={}, + 
activation_quantization_method=QuantizationMethod.POWER_OF_TWO, + activation_n_bits=a_nbits, + supported_input_activation_n_bits=[8], + enable_activation_quantization=True, + quantization_preserving=False, + fixed_scale=None, + fixed_zero_point=None, + simd_size=32, + signedness=schema.Signedness.AUTO) for a_nbits in last_node_activation_nbits]) + + cfg_options = schema.QuantizationConfigOptions(quantization_configurations=options, base_config=options[0]) + + return cfg_options + +class BaseFusingInfoGeneratorTest(abc.ABC): + + fw_impl: FrameworkImplementation + fw_info: FrameworkInfo + attach_to_fw_func: Callable + expected_fi: FusingInfo + last_node_activation_nbits: List[int] + + def _data_gen(self): + raise NotImplementedError() + + def _get_model(self): + raise NotImplementedError() + + def _get_tpc(self, default_quant_cfg_options): + raise NotImplementedError() + + def _get_qc(self): + raise NotImplementedError() + + + @pytest.fixture + def graph_with_fusion_metadata(self): + """ + Creates a graph with fusing metadata based on a generated model and a predefined configuration. + Ensures all required components (framework implementation, framework info, etc.) are present. 
+ """ + assert self._data_gen is not None + assert self.fw_impl is not None + assert self.fw_info is not None + assert self.attach_to_fw_func is not None + assert self.expected_fi is not None + assert self.last_node_activation_nbits is not None + + self.fqc = self.attach_to_fw_func(self._get_tpc(minimal_cfg_options()), + self._get_qc().custom_tpc_opset_to_layer) + + graph_with_fusion_metadata = graph_preparation_runner(self._get_model(), + self._data_gen, + self._get_qc(), + fw_info=self.fw_info, + fw_impl=self.fw_impl, + fqc=self.fqc, + mixed_precision_enable=True, + running_gptq=False) + return graph_with_fusion_metadata + + @pytest.fixture + def fused_graph(self, graph_with_fusion_metadata): + return GraphFuser().apply_node_fusion(graph_with_fusion_metadata) + + def test_expected_fusing_info(self, graph_with_fusion_metadata): + actual_fi = graph_with_fusion_metadata.fusing_info + assert self.expected_fi.node_to_fused_node_map == actual_fi.node_to_fused_node_map + + def test_expected_fused_graph(self, fused_graph): + expected_fused_nodes = self.expected_fi.fusing_data + graph_node_names = [node.name for node in fused_graph.nodes] + + for fused_node_name, original_nodes in expected_fused_nodes.items(): + # 1. Fused node must exist + assert fused_node_name in graph_node_names, f"Fused node '{fused_node_name}' not found in graph." + fused_node = fused_graph.find_node_by_name(fused_node_name) + assert len(fused_node) == 1, f"Expected to find a single node, but found {len(fused_node)}" + fused_node = fused_node[0] + + # 2. Original nodes should not exist anymore + for node in original_nodes: + assert node.name not in graph_node_names, ( + f"Original node '{node.name}' should be fused into '{fused_node_name}', " + f"but it's still in the graph." + ) + + # 3. 
Final quantization configs + if original_nodes[0].final_weights_quantization_cfg is not None: + assert fused_node.final_weights_quantization_cfg == original_nodes[0].final_weights_quantization_cfg, (f"Incorrect final_weights_quantization_cfg for '{fused_node_name}'") + + if original_nodes[-1].final_activation_quantization_cfg is not None: + assert fused_node.final_activation_quantization_cfg == original_nodes[-1].final_activation_quantization_cfg, (f"Incorrect final_activation_quantization_cfg for '{fused_node_name}'") + + # 4. Candidate quantization configs + expected_candidates = original_nodes[-1].candidates_quantization_cfg + actual_candidates = fused_node.candidates_quantization_cfg + + assert len(actual_candidates) == len(expected_candidates), ( + f"Mismatch in number of candidate quantization configs for '{fused_node_name}'") + + # Extract and sort the n_bits values for comparison + actual_nbits = sorted([c.activation_quantization_cfg.activation_n_bits for c in actual_candidates]) + expected_nbits = sorted([c.activation_quantization_cfg.activation_n_bits for c in expected_candidates]) + + assert actual_nbits == expected_nbits, ( + f"Mismatch in activation_n_bits list for '{fused_node_name}': " + f"expected {expected_nbits}, got {actual_nbits}") + + # Optionally also assert all weights quant configs are None + for i, actual in enumerate(actual_candidates): + assert actual.weights_quantization_cfg is None, ( + f"Weights quant config should be None in fused candidate #{i} for '{fused_node_name}'") + + + + + + + diff --git a/tests_pytest/_fw_tests_common_base/fusing/base_graph_with_fusing_metadata_test.py b/tests_pytest/_fw_tests_common_base/fusing/base_graph_with_fusing_metadata_test.py new file mode 100644 index 000000000..379f37faf --- /dev/null +++ b/tests_pytest/_fw_tests_common_base/fusing/base_graph_with_fusing_metadata_test.py @@ -0,0 +1,240 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import Callable, Any + +import copy + +import abc + +import pytest + + +import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema +from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo +from model_compression_toolkit.core.common import BaseNode, Graph +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.graph.edge import EDGE_SOURCE_INDEX, EDGE_SINK_INDEX +from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner +from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser +from tests_pytest._test_util.tpc_util import minimal_cfg_options + + +class BaseGraphWithFusingMetadataTest(abc.ABC): + + fw_impl: FrameworkImplementation + fw_info: FrameworkInfo + attach_to_fw_func: Callable + layer_class_relu: Any # needed for test_fail_validate_after_adding_node_that_adds_a_fusion + + def _data_gen(self): + raise NotImplementedError() + + def _get_model(self): + raise NotImplementedError() + + @pytest.fixture + def minimal_tpc_with_fusing(self): + """ + Fixture that provides a minimal Target Platform Capabilities (TPC) config used by the + `graph_with_fusion_metadata` fixture. 
+ + minimal_tpc_with_fusing is used as a fixture to provide graph_with_fusion_metadata, which is a required + fixture for the actual test functions. While minimal_tpc_with_fusing itself isn’t used directly in tests, + defining it as a fixture makes its usage cleaner. + + """ + return schema.TargetPlatformCapabilities( + default_qco=minimal_cfg_options(), + tpc_platform_type='test', + operator_set=[schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name=schema.OperatorSetNames.RELU), + schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED), + schema.OperatorsSet(name=schema.OperatorSetNames.SOFTMAX)], + fusing_patterns=[schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name=schema.OperatorSetNames.RELU))), + schema.Fusing( + operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED), + schema.OperatorsSet(name=schema.OperatorSetNames.SOFTMAX))), + schema.Fusing( + operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED), + schema.OperatorsSet(name=schema.OperatorSetNames.SOFTMAX), + schema.OperatorsSet(name=schema.OperatorSetNames.RELU))) + ] + ) + + @pytest.fixture + def graph_with_fusion_metadata(self, minimal_tpc_with_fusing): + """ + Creates a graph with fusing metadata based on a generated model and a predefined configuration. + Ensures all required components (framework implementation, framework info, etc.) are present. 
+ """ + assert self._data_gen is not None + assert self.fw_impl is not None + assert self.fw_info is not None + assert self.attach_to_fw_func is not None + + self.fqc = self.attach_to_fw_func(minimal_tpc_with_fusing) + + graph_with_fusion_metadata = graph_preparation_runner(self._get_model(), + self._data_gen, + QuantizationConfig(), + fw_info=self.fw_info, + fw_impl=self.fw_impl, + fqc=self.fqc, + mixed_precision_enable=False, + running_gptq=False) + return graph_with_fusion_metadata + + def test_expected_fusing_info(self, graph_with_fusion_metadata): + """ + Test that the graph contains expected metadata regard the fusing that should + be found in the model. + """ + actual_fi = graph_with_fusion_metadata.fusing_info + assert len(actual_fi.get_all_fused_operations()) == 2 + assert sorted(actual_fi.get_all_fused_operations().keys()) == ['FusedNode_conv_relu', 'FusedNode_linear_softmax'] + assert actual_fi.node_to_fused_node_map == {'conv': 'FusedNode_conv_relu', + 'relu': 'FusedNode_conv_relu', + 'linear': 'FusedNode_linear_softmax', + 'softmax': 'FusedNode_linear_softmax'} + + def test_disable_act_quantization(self, graph_with_fusion_metadata: Graph): + """Tests that the correct nodes have activation quantization disabled after + calling _disable_nodes_activation_quantization. 
+ """ + for node in graph_with_fusion_metadata.nodes: + for qc in node.candidates_quantization_cfg: + qc.activation_quantization_cfg.enable_activation_quantization = True + + graph_with_fusion_metadata.disable_fused_nodes_activation_quantization() + disabled_nodes = [ + node.name for node in graph_with_fusion_metadata.nodes + if all(not qc.activation_quantization_cfg.enable_activation_quantization + for qc in node.candidates_quantization_cfg) + ] + + expected = ['conv', 'linear'] + assert sorted(disabled_nodes) == expected, f"Expected {expected}, but got {sorted(disabled_nodes)}" + + def test_fail_validate_after_node_removal(self, graph_with_fusion_metadata): + """ + Tests validation failure after removing a node that is part of a fusion pattern. + - Replaces a ReLU node with a new Tanh node. + - Expects validation to fail because ReLU was part of a defined fusion pattern. + """ + relu_node = graph_with_fusion_metadata.find_node_by_name('relu')[0] + new_node = BaseNode( + name='tanh', + framework_attr={}, + input_shape=relu_node.input_shape, + output_shape=relu_node.output_shape, + weights={}, + layer_class="Tanh" + ) + with pytest.raises(ValueError): + graph_with_fusion_metadata.replace_node(relu_node, new_node) + + def test_fail_validate_after_topology_change(self, graph_with_fusion_metadata): + """ + Tests validation failure after modifying the graph topology by adding an unintended edge. + - Adds an edge from Conv2D to Flatten, creating multiple successors. + - Expects validation to fail as the topology no longer follows expected fusing rules. 
+ """ + conv_node = graph_with_fusion_metadata.find_node_by_name('conv')[0] + flatten_node = graph_with_fusion_metadata.find_node_by_name('flatten')[0] + with pytest.raises(ValueError): + graph_with_fusion_metadata.add_edge(conv_node, flatten_node, **{EDGE_SOURCE_INDEX: 1, EDGE_SINK_INDEX: 1}) + + def test_fail_validate_after_adding_node_between_conv_to_relu(self, graph_with_fusion_metadata): + """ + Tests validation failure after inserting a node between fused Conv2D and ReLU layers. + - Removes the edge between Conv2D and ReLU. + - Expects validation to fail as the fusion sequence is broken. + """ + conv_node = graph_with_fusion_metadata.find_node_by_name('conv')[0] + relu_node = graph_with_fusion_metadata.find_node_by_name('relu')[0] + with pytest.raises(ValueError): + graph_with_fusion_metadata.remove_edge(conv_node, relu_node) + with pytest.raises(ValueError): + graph_with_fusion_metadata.validate() + + # After updating the fusing info, make sure validation passes + graph_with_fusion_metadata.fusing_info.remove_fused_operation('FusedNode_conv_relu') + graph_with_fusion_metadata.validate() + + def test_valid_change_in_graph(self, graph_with_fusion_metadata): + """ + Tests validation passes after changing the graph with a change that should not + affect the fusing metadata. 
+ """ + graph_with_fusion_metadata.validate() + # Add softmax node after current softmax + softmax_node = graph_with_fusion_metadata.find_node_by_name('softmax')[0] + new_softmax_node = BaseNode( + name='new_softmax', + framework_attr={}, + input_shape=softmax_node.output_shape, + output_shape=softmax_node.output_shape, + weights={}, + layer_class='softmax' + ) + graph_with_fusion_metadata.add_node(new_softmax_node) + graph_with_fusion_metadata.add_edge(softmax_node, new_softmax_node, **{EDGE_SOURCE_INDEX: 0, EDGE_SINK_INDEX: 0}) + graph_with_fusion_metadata.validate() + + def test_fail_validate_after_serialization_deserialization(self, graph_with_fusion_metadata): + """ + Tests validation failure after serializing and deserializing the graph. + - Serializes and deserializes the graph to ensure stability. + - Breaks fusing by removing an edge and checks if validation fails. + """ + graph_copy = copy.deepcopy(graph_with_fusion_metadata) + graph_copy.validate() + conv_node = graph_copy.find_node_by_name('conv')[0] + relu_node = graph_copy.find_node_by_name('relu')[0] + with pytest.raises(ValueError): + graph_copy.remove_edge(conv_node, relu_node) + + def test_fail_validate_after_adding_node_that_adds_a_fusion(self, graph_with_fusion_metadata): + """ + Tests validation failure after introducing a new fusion pattern by adding a node. + - Adds a ReLU node after Softmax. + - The resulting pattern FullyConnected -> Softmax -> ReLU is a defined fusion. + - Since this pattern didn't exist in the original graph, re-running fusing should raise an error. 
+ """ + # Step 1: Validate original graph (should pass) + graph_with_fusion_metadata.validate() + fuser = GraphFuser() + fuser.apply_node_fusion(graph_with_fusion_metadata) + + # Step 2: Add ReLU node after softmax + softmax_node = graph_with_fusion_metadata.find_node_by_name('softmax')[0] + relu_node = BaseNode( + name='new_relu', + framework_attr={}, + input_shape=softmax_node.output_shape, + output_shape=softmax_node.output_shape, + weights={}, + layer_class=self.layer_class_relu + ) + graph_with_fusion_metadata.add_node(relu_node) + graph_with_fusion_metadata.add_edge(softmax_node, relu_node, **{EDGE_SOURCE_INDEX: 0, EDGE_SINK_INDEX: 0}) + + # Step 3: Run fuser and expect failure due to unexpected new fusion + with pytest.raises(ValueError): + fuser.apply_node_fusion(graph_with_fusion_metadata) + + + diff --git a/tests_pytest/_test_util/tpc_util.py b/tests_pytest/_test_util/tpc_util.py index dad7d8544..c2d3315dc 100644 --- a/tests_pytest/_test_util/tpc_util.py +++ b/tests_pytest/_test_util/tpc_util.py @@ -50,12 +50,11 @@ def build_mp_config_options_for_kernel_bias_ops(base_w_config: AttributeQuantiza return QuantizationConfigOptions(quantization_configurations=mp_configs, base_config=base_op_config) -def minimal_tpc(): +def minimal_cfg_options(): """ - Minimal TPC. Is intended to be used in integration tests, when real TPC (as opposed to mock) is needed, - but we don't care about its content. - - There is also a fixture form by the same name. + Minimal op configuration options. Is intended to be used in integration tests, + when real TPC (as opposed to mock) is needed, and we care about some of its content (like fusing) + but we don't care about the default configuration options. """ op_cfg = OpQuantizationConfig( default_weight_attr_config={}, @@ -72,6 +71,19 @@ def minimal_tpc(): cfg_options = QuantizationConfigOptions(quantization_configurations=[op_cfg]) + return cfg_options + + +def minimal_tpc(): + """ + Minimal TPC. 
Is intended to be used in integration tests, when real TPC (as opposed to mock) is needed, + but we don't care about its content. + + There is also a fixture form by the same name. + """ + + cfg_options = minimal_cfg_options() + return TargetPlatformCapabilities(default_qco=cfg_options, tpc_platform_type='test', operator_set=None, diff --git a/tests_pytest/common_tests/unit_tests/core/test_fusion_info.py b/tests_pytest/common_tests/unit_tests/core/test_fusion_info.py new file mode 100644 index 000000000..2cf2d237a --- /dev/null +++ b/tests_pytest/common_tests/unit_tests/core/test_fusion_info.py @@ -0,0 +1,221 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import pytest +from unittest.mock import Mock + +from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator, FUSED_OP_ID_PREFIX, FusingInfo +from model_compression_toolkit.target_platform_capabilities import FrameworkQuantizationCapabilities +from model_compression_toolkit.core.common import BaseNode + + +class MockBaseNode: + """ + Mock implementation of a base graph node. + Allows for equality checks and hashing based on the node name. 
+ """ + + def __init__(self, name: str): + self.name = name + + def __eq__(self, other): + return isinstance(other, MockBaseNode) and self.name == other.name + + def __hash__(self): + return hash(self.name) + + +@pytest.fixture +def fusing_patterns(): + """ + - Returns predefined fusing patterns: Conv2D + ReLU and Linear + Softmax. + """ + return [["Conv2d", "ReLU"], ["Linear", "Softmax"]] + + +@pytest.fixture +def mock_nodes(): + """ + Creates mock nodes representing a simple neural network structure. + - Nodes: Conv2D, ReLU, Linear, Softmax. + """ + node1 = Mock(spec=BaseNode) + node1.name = "conv" + node1.layer_class = "Conv2d" + + node2 = Mock(spec=BaseNode) + node2.name = "relu" + node2.layer_class = "ReLU" + + node3 = Mock(spec=BaseNode) + node3.name = "linear" + node3.layer_class = "Linear" + + node4 = Mock(spec=BaseNode) + node4.name = "softmax" + node4.layer_class = "Softmax" + + return [node1, node2, node3, node4] + + +@pytest.fixture +def mock_graph(mock_nodes): + """ + Creates a mock graph with topologically sorted nodes and defined connectivity. + - Implements `get_next_nodes` and `get_prev_nodes` to maintain linear order. 
+ """ + graph = Mock() + graph.nodes.return_value = mock_nodes + graph.get_topo_sorted_nodes.return_value = mock_nodes + + adjacency = { + mock_nodes[0]: [mock_nodes[1]], # conv -> relu + mock_nodes[1]: [mock_nodes[2]], # relu -> linear + mock_nodes[2]: [mock_nodes[3]], # linear -> softmax + mock_nodes[3]: [] # softmax has no outputs + } + + reverse_adjacency = { + mock_nodes[0]: [], # conv has no inputs + mock_nodes[1]: [mock_nodes[0]], # relu <- conv + mock_nodes[2]: [mock_nodes[1]], # linear <- relu + mock_nodes[3]: [mock_nodes[2]] # softmax <- linear + } + + graph.get_next_nodes.side_effect = lambda node: adjacency.get(node, []) + graph.get_prev_nodes.side_effect = lambda node: reverse_adjacency.get(node, []) + + return graph + + +@pytest.fixture +def fusing_info_generator(fusing_patterns): + """ + Creates a FusingInfoGenerator using the fusing patterns. + """ + return FusingInfoGenerator(fusing_patterns) + + +def test_fusing_info_number_of_operations(mock_graph, fusing_info_generator): + """ + Tests that the correct number of fused operations is detected. + - Expects 2 fused operations: Conv2D + ReLU, Linear + Softmax. + """ + fi = fusing_info_generator.generate_fusing_info(mock_graph) + fused_operations = fi.get_all_fused_operations() + assert len(fused_operations) == 2, "Expected 2 fused operations" + + +def test_fusing_info_operation_contents(mock_graph, fusing_info_generator, mock_nodes): + """ + Tests that the fused operations contain the correct node groups. + - Checks that the correct node names are assigned to each fused operation. 
+ """ + fi = fusing_info_generator.generate_fusing_info(mock_graph) + fused_operations = fi.get_all_fused_operations() + + expected_op1_id = f"{FUSED_OP_ID_PREFIX}conv_relu" + expected_op2_id = f"{FUSED_OP_ID_PREFIX}linear_softmax" + + assert expected_op1_id in fused_operations, f"{expected_op1_id} not found" + assert expected_op2_id in fused_operations, f"{expected_op2_id} not found" + + conv_node, relu_node, linear_node, softmax_node = mock_nodes + + assert [f.name for f in fused_operations[expected_op1_id]] == [conv_node.name, + relu_node.name], "Incorrect nodes in first fused operation" + assert [f.name for f in fused_operations[expected_op2_id]] == [linear_node.name, + softmax_node.name], "Incorrect nodes in second fused operation" + + +def test_fusing_info_node_mapping(mock_graph, fusing_info_generator, mock_nodes): + """ + Tests that each node is correctly mapped to its fused operation. + """ + fi = fusing_info_generator.generate_fusing_info(mock_graph) + node_to_fused_map = fi.get_node_to_fused_node_map() + + conv_node, relu_node, linear_node, softmax_node = mock_nodes + + expected_op1_id = f"{FUSED_OP_ID_PREFIX}conv_relu" + expected_op2_id = f"{FUSED_OP_ID_PREFIX}linear_softmax" + + assert node_to_fused_map[conv_node.name] == expected_op1_id + assert node_to_fused_map[relu_node.name] == expected_op1_id + assert node_to_fused_map[linear_node.name] == expected_op2_id + assert node_to_fused_map[softmax_node.name] == expected_op2_id + + +def test_fusing_info_validation(mock_graph, fusing_info_generator): + """ + Tests that the fusing info successfully validates a correct graph. + - If validation raises an error, the test fails. + """ + fi = fusing_info_generator.generate_fusing_info(mock_graph) + fi.validate(mock_graph) + + +def test_fusing_info_validation_failure_topology_change(mock_graph, fusing_info_generator, mock_nodes): + """ + Tests that validation fails when the graph topology is altered incorrectly. 
+ - Adds an extra node, creating multiple successors for a node. + - Expects validation to fail with a ValueError. + """ + fusing_info = fusing_info_generator.generate_fusing_info(mock_graph) + extra_node = Mock(spec=BaseNode) + extra_node.name = 'extra_node_name' + + def modified_get_next_nodes(node): + if node == mock_nodes[0]: + return [mock_nodes[1], extra_node] # Conv now has two successors + return [] + + mock_graph.get_next_nodes.side_effect = modified_get_next_nodes + + with pytest.raises(ValueError): + fusing_info.validate(mock_graph) + +def test_add_fused_operation_adds_data(mock_graph, fusing_info_generator): + fi = FusingInfo() + node1 = MockBaseNode("a") + node2 = MockBaseNode("b") + op_id = f"{FUSED_OP_ID_PREFIX}a_b" + fi.add_fused_operation(op_id, (node1, node2)) + + assert op_id in fi.get_all_fused_operations() + assert fi.get_fused_node_name("a") == op_id + assert fi.get_fused_node_name("b") == op_id + +def test_remove_fused_operation_raises_for_missing_op(mock_graph, fusing_info_generator): + fi = FusingInfo() + with pytest.raises(ValueError, match="Fused operation __fused__missing does not exist"): + fi.remove_fused_operation("__fused__missing") + +def test_is_node_in_fused_op_returns_true_for_present_node(mock_graph, fusing_info_generator): + node1 = MockBaseNode("a") + node2 = MockBaseNode("b") + fi = FusingInfo(fusing_data={f"{FUSED_OP_ID_PREFIX}a_b": (node1, node2)}) + + assert fi.is_node_in_fused_op(node1) + assert fi.is_node_in_fused_op(node2) + +def test_is_node_in_fused_op_returns_false_for_absent_node(mock_graph, fusing_info_generator): + node1 = MockBaseNode("a") + node2 = MockBaseNode("b") + fi = FusingInfo(fusing_data={f"{FUSED_OP_ID_PREFIX}a_b": (node1, node2)}) + + unrelated = MockBaseNode("unrelated") + assert not fi.is_node_in_fused_op(unrelated) + diff --git a/tests_pytest/keras_tests/integration_tests/core/fusion/__init__.py b/tests_pytest/keras_tests/integration_tests/core/fusion/__init__.py new file mode 100644 index 
000000000..5397dea24 --- /dev/null +++ b/tests_pytest/keras_tests/integration_tests/core/fusion/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/tests_pytest/keras_tests/integration_tests/core/fusion/test_fusing_info_generator_keras.py b/tests_pytest/keras_tests/integration_tests/core/fusion/test_fusing_info_generator_keras.py new file mode 100644 index 000000000..9c9706fb7 --- /dev/null +++ b/tests_pytest/keras_tests/integration_tests/core/fusion/test_fusing_info_generator_keras.py @@ -0,0 +1,368 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import tensorflow as tf +from keras import Input +from tensorflow.keras import layers, Model + +from model_compression_toolkit.core import QuantizationConfig, CustomOpsetLayers +from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo +from tests_pytest._fw_tests_common_base.fusing.base_fusing_info_generator_test import BaseFusingInfoGeneratorTest, \ + random_activation_configs, get_activation_mp_options +from tests_pytest._test_util.graph_builder_utils import build_node +from tests_pytest.keras_tests.keras_test_util.keras_test_mixin import KerasFwMixin +import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema + +from tensorflow.keras import backend as K + + +class BaseTestFusingInfoGeneratorKeras(BaseFusingInfoGeneratorTest, KerasFwMixin): + + K.clear_session() # Reset global layer naming to avoid name conflicts across tests + + def _data_gen(self): + return self.get_basic_data_gen(shapes=[(1, 16, 16, 3)])() + + def _get_qc(self): + qc = QuantizationConfig( + custom_tpc_opset_to_layer={"AnyAct": CustomOpsetLayers( + [layers.ReLU, layers.Activation, tf.nn.swish])}) + return qc + + + +class TestFusingConvRelu(BaseTestFusingInfoGeneratorKeras): + + last_node_activation_nbits, qcs = random_activation_configs() + + fusing_patterns = [ + schema.Fusing(operator_groups=( + schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name=schema.OperatorSetNames.RELU))) + ] + + expected_fi = FusingInfo( + fusing_patterns=fusing_patterns, + fusing_data={ + "FusedNode_conv1_conv2_collapsed_relu": ( + build_node(name="conv1_conv2_collapsed"), + build_node(name="relu", qcs=qcs) + ) + } + ) + + def _get_tpc(self, default_quant_cfg_options): + conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV) + relu = schema.OperatorsSet(name=schema.OperatorSetNames.RELU, + 
qc_options=get_activation_mp_options(self.last_node_activation_nbits)) + return schema.TargetPlatformCapabilities( + default_qco=default_quant_cfg_options, + tpc_platform_type="test", + operator_set=[conv, relu], + fusing_patterns=self.fusing_patterns + ) + + def _get_model(self): + inputs = Input(shape=(None, None, 3)) + x = layers.Conv2D(16, kernel_size=(3, 3), padding='valid', name="conv1")(inputs) + x = layers.Conv2D(32, kernel_size=(1, 1), padding='valid', name="conv2")(x) + outputs = layers.ReLU(name="relu")(x) + return Model(inputs=inputs, outputs=outputs) + + +class TestFusingAnyActKeras(BaseTestFusingInfoGeneratorKeras): + + last_node_activation_nbits, qcs = random_activation_configs() + + fusing_patterns = [ + schema.Fusing(operator_groups=( + schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name="AnyAct"))) + ] + + expected_fi = FusingInfo( + fusing_patterns=fusing_patterns, + fusing_data={ + "FusedNode_conv1_conv2_collapsed_tanh": + (build_node(name="conv1_conv2_collapsed"), + build_node(name="tanh", qcs=qcs)), + "FusedNode_conv3_relu": + (build_node(name="conv3"), + build_node(name="relu", qcs=qcs)), + "FusedNode_conv4_sigmoid": + (build_node(name="conv4"), + build_node(name="sigmoid", qcs=qcs)), + "FusedNode_conv5_tf.nn.silu": + (build_node(name="conv5"), + build_node(name="tf.nn.silu", qcs=qcs)), + } + ) + + def _get_tpc(self, default_quant_cfg_options): + conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV) + any_act = schema.OperatorsSet(name="AnyAct", + qc_options=get_activation_mp_options(self.last_node_activation_nbits)) + return schema.TargetPlatformCapabilities( + default_qco=default_quant_cfg_options, + tpc_platform_type="test", + operator_set=[conv, any_act], + fusing_patterns=self.fusing_patterns + ) + + def _get_model(self): + inputs = Input(shape=(32, 32, 3)) + x = layers.Conv2D(32, kernel_size=(3, 3), name="conv1")(inputs) + x = layers.Conv2D(32, kernel_size=(1, 1), name="conv2")(x) + x = 
layers.Activation("tanh", name="tanh")(x) + x = layers.Conv2D(32, kernel_size=(3, 3), name="conv3")(x) + x = layers.ReLU(name="relu")(x) + x = layers.Conv2D(64, kernel_size=(1, 1), name="conv4")(x) + x = layers.Activation("sigmoid", name="sigmoid")(x) + x = layers.Conv2D(64, kernel_size=(2, 2), name="conv5")(x) + outputs = layers.Activation("swish", name="tf.nn.silu")(x) + return Model(inputs=inputs, outputs=outputs) + + +class TestFusingConvReLUOnlyKeras(BaseTestFusingInfoGeneratorKeras): + + last_node_activation_nbits, qcs = random_activation_configs() + + fusing_patterns = [ + schema.Fusing(operator_groups=( + schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name="AnyAct"))) + ] + + expected_fi = FusingInfo( + fusing_patterns=fusing_patterns, + fusing_data={ + "FusedNode_conv1_conv2_collapsed_tanh": + (build_node(name="conv1_conv2_collapsed"), + build_node(name="tanh", qcs=qcs)), + "FusedNode_conv3_relu": + (build_node(name="conv3"), + build_node(name="relu", qcs=qcs)), + "FusedNode_conv4_sigmoid": + (build_node(name="conv4"), + build_node(name="sigmoid", qcs=qcs)), + "FusedNode_conv5_swish": + (build_node(name="conv5"), + build_node(name="swish", qcs=qcs)), + } + ) + + def _get_tpc(self, default_quant_cfg_options): + conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV) + any_act = schema.OperatorsSet(name="AnyAct", + qc_options=get_activation_mp_options(self.last_node_activation_nbits)) + return schema.TargetPlatformCapabilities( + default_qco=default_quant_cfg_options, + tpc_platform_type="test", + operator_set=[conv, any_act], + fusing_patterns=self.fusing_patterns + ) + + def _get_model(self): + inputs = Input(shape=(32, 32, 3)) + x = layers.Conv2D(32, kernel_size=(3, 3), name="conv1")(inputs) + x = layers.Conv2D(32, kernel_size=(1, 1), name="conv2")(x) + x = layers.Activation("tanh", name="tanh")(x) + x = layers.Conv2D(32, kernel_size=(3, 3), name="conv3")(x) + x = layers.ReLU(name="relu")(x) + x = layers.Conv2D(64, 
kernel_size=(1, 1), name="conv4")(x) + x = layers.Activation("sigmoid", name="sigmoid")(x) + x = layers.Conv2D(64, kernel_size=(2, 2), name="conv5")(x) + outputs = layers.Activation(tf.nn.swish, name="swish")(x) + return Model(inputs=inputs, outputs=outputs) +class TestFusingComplexPatternsKeras(BaseTestFusingInfoGeneratorKeras): + + last_node_activation_nbits, qcs = random_activation_configs() + + fusing_patterns = [ + schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name=schema.OperatorSetNames.SWISH))), + schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name=schema.OperatorSetNames.ADD), + schema.OperatorsSet(name=schema.OperatorSetNames.SWISH))), + schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name=schema.OperatorSetNames.SWISH), + schema.OperatorsSet(name=schema.OperatorSetNames.ADD))), + schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED), + schema.OperatorsSet(name=schema.OperatorSetNames.SWISH))), + schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name=schema.OperatorSetNames.RELU))), + schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name=schema.OperatorSetNames.RELU), + schema.OperatorsSet(name=schema.OperatorSetNames.ADD))), + ] + + expected_fi = FusingInfo( + fusing_patterns=fusing_patterns, + fusing_data={ + "FusedNode_conv1_swish1_add": + (build_node(name="conv1"), + build_node(name="swish1"), + build_node(name="add", qcs=qcs)), + "FusedNode_conv2_swish2_add_1": + (build_node(name="conv2"), + build_node(name="swish2"), + build_node(name="add_1", qcs=qcs)), + "FusedNode_conv3_relu": + (build_node(name="conv3"), + build_node(name="relu", qcs=qcs)), + "FusedNode_conv4_relu_1_add_2": + (build_node(name="conv4"), + 
build_node(name="relu_1"), + build_node(name="add_2", qcs=qcs)), + "FusedNode_dense1_swish3": + (build_node(name="dense1"), + build_node(name="swish3", qcs=qcs)), + "FusedNode_dense2_swish4": + (build_node(name="dense2"), + build_node(name="swish4", qcs=qcs)), + } + ) + + def _get_tpc(self, default_quant_cfg_options): + opsets = [ + schema.OperatorsSet(name=schema.OperatorSetNames.CONV, + qc_options=get_activation_mp_options(self.last_node_activation_nbits)), + schema.OperatorsSet(name=schema.OperatorSetNames.ADD, + qc_options=get_activation_mp_options(self.last_node_activation_nbits)), + schema.OperatorsSet(name=schema.OperatorSetNames.RELU, + qc_options=get_activation_mp_options(self.last_node_activation_nbits)), + schema.OperatorsSet(name=schema.OperatorSetNames.SWISH, + qc_options=get_activation_mp_options(self.last_node_activation_nbits)), + schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED, + qc_options=get_activation_mp_options(self.last_node_activation_nbits)), + ] + return schema.TargetPlatformCapabilities( + default_qco=default_quant_cfg_options, + tpc_platform_type="test", + operator_set=opsets, + fusing_patterns=self.fusing_patterns + ) + + def _get_model(self): + inputs = Input(shape=(32, 32, 3)) + + x = layers.Conv2D(3, (3, 3), padding='same', name="conv1")(inputs) + x = layers.Activation('swish', name="swish1")(x) + x = layers.Add(name="add")([x, inputs]) + + x2 = layers.Conv2D(3, (1, 1), padding='same', name="conv2")(x) + x2 = layers.Activation('swish', name="swish2")(x2) + x2 = layers.Add(name="add_1")([x, x2]) + + x3 = layers.Conv2D(3, (3, 3), padding='same', name="conv3")(x2) + x3 = layers.ReLU(name="relu")(x3) + + x4 = layers.Conv2D(3, (1, 1), padding='same', name="conv4")(x3) + x4 = layers.ReLU(name="relu_1")(x4) + x4 = layers.Add(name="add_2")([x4, x3]) + + x4 = layers.Flatten()(x4) + x4 = layers.Dense(16, name="dense1")(x4) + x4 = layers.Activation('swish', name="swish3")(x4) + + x4 = layers.Dense(16, name="dense2")(x4) + 
outputs = layers.Activation('swish', name="swish4")(x4) + + return Model(inputs=inputs, outputs=outputs) + +class TestFusingConvSwishWithMultiSuccessorsKeras(BaseTestFusingInfoGeneratorKeras): + + last_node_activation_nbits, qcs = random_activation_configs() + + fusing_patterns = [ + schema.Fusing(operator_groups=( + schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name=schema.OperatorSetNames.SWISH))) + ] + + expected_fi = FusingInfo( + fusing_patterns=fusing_patterns, + fusing_data={ + "FusedNode_conv1_swish": ( + build_node(name="conv1"), + build_node(name="swish", qcs=qcs) + ) + } + ) + + def _get_tpc(self, default_quant_cfg_options): + conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV, + qc_options=get_activation_mp_options(self.last_node_activation_nbits)) + swish = schema.OperatorsSet(name=schema.OperatorSetNames.SWISH, + qc_options=get_activation_mp_options(self.last_node_activation_nbits)) + return schema.TargetPlatformCapabilities( + default_qco=default_quant_cfg_options, + tpc_platform_type="test", + operator_set=[conv, swish], + fusing_patterns=self.fusing_patterns + ) + + def _get_model(self): + inputs = Input(shape=(32, 32, 3)) + x = layers.Conv2D(16, (3, 3), padding='same', name="conv1")(inputs) + x = layers.Activation(tf.nn.swish, name="swish")(x) + + # Multiple successors of swish + branch1 = layers.Conv2D(8, (1, 1), name="branch1")(x) + branch2 = layers.Conv2D(8, (1, 1), name="branch2")(x) + outputs = layers.Add(name="add")([branch1, branch2]) + return Model(inputs=inputs, outputs=outputs) + +class TestFusingConvReluWithMultiPredecessorsKeras(BaseTestFusingInfoGeneratorKeras): + + last_node_activation_nbits, qcs = random_activation_configs() + + fusing_patterns = [ + schema.Fusing(operator_groups=( + schema.OperatorsSet(name=schema.OperatorSetNames.CONV), + schema.OperatorsSet(name=schema.OperatorSetNames.RELU))) + ] + + expected_fi = FusingInfo( + fusing_patterns=fusing_patterns, + fusing_data={ + 
"FusedNode_conv3_relu": ( + build_node(name="conv3"), + build_node(name="relu", qcs=qcs) + ) + } + ) + + def _get_tpc(self, default_quant_cfg_options): + conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV) + relu = schema.OperatorsSet(name=schema.OperatorSetNames.RELU, + qc_options=get_activation_mp_options(self.last_node_activation_nbits)) + return schema.TargetPlatformCapabilities( + default_qco=default_quant_cfg_options, + tpc_platform_type="test", + operator_set=[conv, relu], + fusing_patterns=self.fusing_patterns + ) + + def _get_model(self): + inputs = Input(shape=(32, 32, 3)) + x1 = layers.Conv2D(16, (3, 3), padding='same', name="conv1")(inputs) + x2 = layers.Conv2D(16, (3, 3), padding='same', name="conv2")(inputs) + + # Merge before relu + merged = layers.Add(name="merge")([x1, x2]) + x = layers.Conv2D(16, (3, 3), padding='same', name="conv3")(merged) + outputs = layers.ReLU(name="relu")(x) + return Model(inputs=inputs, outputs=outputs) + diff --git a/tests_pytest/keras_tests/integration_tests/core/fusion/test_graph_with_fusing_metadata_keras.py b/tests_pytest/keras_tests/integration_tests/core/fusion/test_graph_with_fusing_metadata_keras.py new file mode 100644 index 000000000..7c9ac09e1 --- /dev/null +++ b/tests_pytest/keras_tests/integration_tests/core/fusion/test_graph_with_fusing_metadata_keras.py @@ -0,0 +1,37 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +from tests_pytest._fw_tests_common_base.fusing.base_graph_with_fusing_metadata_test import BaseGraphWithFusingMetadataTest +from tests_pytest.keras_tests.keras_test_util.keras_test_mixin import KerasFwMixin + +import keras + +class TestGraphWithFusionMetadataKeras(BaseGraphWithFusingMetadataTest, KerasFwMixin): + + layer_class_relu = keras.layers.ReLU + + def _data_gen(self): + return self.get_basic_data_gen(shapes=[(1, 3, 5, 5)])() + + def _get_model(self): + model = keras.Sequential([ + keras.layers.Conv2D(3, (3, 3), activation=None, input_shape=(5, 5, 3), name='conv'), + keras.layers.ReLU(name='relu'), + keras.layers.Flatten(name='flatten'), + keras.layers.Dense(10, name='linear'), + keras.layers.Softmax(name='softmax') + ]) + return model diff --git a/tests_pytest/pytorch_tests/integration_tests/core/fusion/__init__.py b/tests_pytest/pytorch_tests/integration_tests/core/fusion/__init__.py new file mode 100644 index 000000000..5397dea24 --- /dev/null +++ b/tests_pytest/pytorch_tests/integration_tests/core/fusion/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== diff --git a/tests_pytest/pytorch_tests/integration_tests/core/fusion/test_fusing_info_generator_torch.py b/tests_pytest/pytorch_tests/integration_tests/core/fusion/test_fusing_info_generator_torch.py new file mode 100644 index 000000000..e89b9e91a --- /dev/null +++ b/tests_pytest/pytorch_tests/integration_tests/core/fusion/test_fusing_info_generator_torch.py @@ -0,0 +1,431 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================
"""Integration tests for the PyTorch fusing-info generator.

Each test class below builds a small torch model, a matching TargetPlatformCapabilities
with fusing patterns, and the FusingInfo the generator is expected to produce.
"""
from unittest.mock import Mock

import torch
import torch.nn as nn

from model_compression_toolkit.core import QuantizationConfig, CustomOpsetLayers
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
from tests_pytest._fw_tests_common_base.fusing.base_fusing_info_generator_test import BaseFusingInfoGeneratorTest, \
    random_activation_configs, get_activation_mp_options
from tests_pytest._test_util.graph_builder_utils import build_node
from tests_pytest.pytorch_tests.torch_test_util.torch_test_mixin import TorchFwMixin

import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema


class BaseTestFusingInfoGeneratorPytorch(BaseFusingInfoGeneratorTest, TorchFwMixin):
    """Shared PyTorch plumbing: input generator and the "AnyAct" custom-opset mapping."""

    def _data_gen(self):
        # Single random input tensor of shape (1, 3, 16, 16).
        return self.get_basic_data_gen(shapes=[(1, 3, 16, 16)])()

    def _get_qc(self):
        # Map the custom "AnyAct" opset name onto the activation layers/functions
        # the test models use (Hardtanh only when configured with min_val=0).
        qc = QuantizationConfig(
            custom_tpc_opset_to_layer={"AnyAct": CustomOpsetLayers([nn.ReLU, nn.functional.relu6, nn.functional.relu, nn.SiLU, nn.Sigmoid, nn.Tanh,
                                                                    LayerFilterParams(nn.Hardtanh, min_val=0)])})
        return qc


class TestFusingConvRelu(BaseTestFusingInfoGeneratorPytorch):
    """Conv+ReLU pattern; the two stacked convs are expected to collapse into one node first."""

    last_node_activation_nbits, qcs = random_activation_configs()

    fusing_patterns = [
        schema.Fusing(operator_groups=(
            schema.OperatorsSet(name=schema.OperatorSetNames.CONV),
            schema.OperatorsSet(name=schema.OperatorSetNames.RELU)))
    ]

    expected_fi = FusingInfo(
        fusing_patterns=fusing_patterns,
        fusing_data={
            "FusedNode_conv1_conv2_collapsed_relu": (
                build_node(name="conv1_conv2_collapsed"),
                build_node(name="relu", qcs=qcs)
            )
        }
    )

    def _get_tpc(self, default_quant_cfg_options):
        conv_set = schema.OperatorsSet(name=schema.OperatorSetNames.CONV,
                                       qc_options=get_activation_mp_options(self.last_node_activation_nbits))
        relu_set = schema.OperatorsSet(name=schema.OperatorSetNames.RELU,
                                       qc_options=get_activation_mp_options(self.last_node_activation_nbits))
        return schema.TargetPlatformCapabilities(
            default_qco=default_quant_cfg_options,
            tpc_platform_type="test",
            operator_set=[conv_set, relu_set],
            fusing_patterns=self.fusing_patterns
        )

    def _get_model(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 16, kernel_size=(3, 3))
                self.conv2 = nn.Conv2d(16, 32, kernel_size=(1, 1))
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return self.relu(x)

        return Model()


class TestFusingAnyAct(BaseTestFusingInfoGeneratorPytorch):
    """Conv followed by any activation from the custom "AnyAct" opset (Hardtanh variant)."""

    last_node_activation_nbits, qcs = random_activation_configs()

    fusing_patterns = [
        schema.Fusing(operator_groups=(
            schema.OperatorsSet(name=schema.OperatorSetNames.CONV),
            schema.OperatorsSet(name="AnyAct")))
    ]

    expected_fi = FusingInfo(
        fusing_patterns=fusing_patterns,
        fusing_data={
            "FusedNode_conv1_conv2_collapsed_tanh":
                (build_node(name="conv1_conv2_collapsed"),
                 build_node(name="tanh", qcs=qcs)),
            "FusedNode_conv3_relu":
                (build_node(name="conv3"),
                 build_node(name="relu", qcs=qcs)),
            "FusedNode_conv4_sigmoid":
                (build_node(name="conv4"),
                 build_node(name="sigmoid", qcs=qcs)),
            "FusedNode_conv5_swish":
                (build_node(name="conv5"),
                 build_node(name="swish", qcs=qcs)),
        }
    )

    def _get_tpc(self, default_quant_cfg_options):
        conv_set = schema.OperatorsSet(name=schema.OperatorSetNames.CONV)
        any_act_set = schema.OperatorsSet(name="AnyAct",
                                          qc_options=get_activation_mp_options(self.last_node_activation_nbits))
        return schema.TargetPlatformCapabilities(
            default_qco=default_quant_cfg_options,
            tpc_platform_type="test",
            operator_set=[conv_set, any_act_set],
            fusing_patterns=self.fusing_patterns
        )

    def _get_model(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 32, kernel_size=(3, 3))
                self.conv2 = nn.Conv2d(32, 32, kernel_size=(1, 1))
                self.conv3 = nn.Conv2d(32, 32, kernel_size=(3, 3))
                self.conv4 = nn.Conv2d(32, 64, kernel_size=(1, 1))
                self.conv5 = nn.Conv2d(64, 64, kernel_size=(2, 2))
                self.relu = nn.ReLU()
                # Hardtanh(min_val=0) matches the LayerFilterParams entry in the "AnyAct" opset.
                self.tanh = nn.Hardtanh(min_val=0)
                self.swish = nn.SiLU()
                self.sigmoid = nn.Sigmoid()

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.tanh(x)
                x = self.conv3(x)
                x = self.relu(x)
                x = self.conv4(x)
                x = self.sigmoid(x)
                x = self.conv5(x)
                return self.swish(x)

        return Model()


class TestFusingConvReLUOnly(BaseTestFusingInfoGeneratorPytorch):
    """Same topology as TestFusingAnyAct but with a plain nn.Tanh.

    NOTE(review): despite its name, this test uses the "AnyAct" pattern (not a
    Conv+ReLU-only pattern); it differs from TestFusingAnyAct only in using
    nn.Tanh instead of nn.Hardtanh(min_val=0). Name kept to avoid changing
    pytest discovery.
    """

    last_node_activation_nbits, qcs = random_activation_configs()

    fusing_patterns = [
        schema.Fusing(operator_groups=(
            schema.OperatorsSet(name=schema.OperatorSetNames.CONV),
            schema.OperatorsSet(name="AnyAct")))
    ]

    expected_fi = FusingInfo(
        fusing_patterns=fusing_patterns,
        fusing_data={
            "FusedNode_conv1_conv2_collapsed_tanh":
                (build_node(name="conv1_conv2_collapsed"), build_node(name="tanh", qcs=qcs)),
            "FusedNode_conv3_relu":
                (build_node(name="conv3"), build_node(name="relu", qcs=qcs)),
            "FusedNode_conv4_sigmoid":
                (build_node(name="conv4"), build_node(name="sigmoid", qcs=qcs)),
            "FusedNode_conv5_swish":
                (build_node(name="conv5"), build_node(name="swish", qcs=qcs))
        }
    )

    def _get_tpc(self, default_quant_cfg_options):
        conv_set = schema.OperatorsSet(name=schema.OperatorSetNames.CONV)
        any_act_set = schema.OperatorsSet(name="AnyAct",
                                          qc_options=get_activation_mp_options(self.last_node_activation_nbits))
        return schema.TargetPlatformCapabilities(
            default_qco=default_quant_cfg_options,
            tpc_platform_type="test",
            operator_set=[conv_set, any_act_set],
            fusing_patterns=self.fusing_patterns
        )

    def _get_model(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 32, kernel_size=(3, 3))
                self.conv2 = nn.Conv2d(32, 32, kernel_size=(1, 1))
                self.conv3 = nn.Conv2d(32, 32, kernel_size=(3, 3))
                self.conv4 = nn.Conv2d(32, 64, kernel_size=(1, 1))
                self.conv5 = nn.Conv2d(64, 64, kernel_size=(2, 2))
                self.relu = nn.ReLU()
                self.tanh = nn.Tanh()
                self.swish = nn.SiLU()
                self.sigmoid = nn.Sigmoid()

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.tanh(x)
                x = self.conv3(x)
                x = self.relu(x)
                x = self.conv4(x)
                x = self.sigmoid(x)
                x = self.conv5(x)
                return self.swish(x)

        return Model()


class TestFusingComplexPatterns(BaseTestFusingInfoGeneratorPytorch):
    """Several overlapping 2- and 3-operator patterns over a residual model;
    longer matches (conv-act-add) are expected to win over shorter ones."""

    last_node_activation_nbits, qcs = random_activation_configs()

    fusing_patterns = [
        schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.CONV),
                                       schema.OperatorsSet(name=schema.OperatorSetNames.SWISH))),
        schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.CONV),
                                       schema.OperatorsSet(name=schema.OperatorSetNames.ADD),
                                       schema.OperatorsSet(name=schema.OperatorSetNames.SWISH))),
        schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.CONV),
                                       schema.OperatorsSet(name=schema.OperatorSetNames.SWISH),
                                       schema.OperatorsSet(name=schema.OperatorSetNames.ADD))),
        schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED),
                                       schema.OperatorsSet(name=schema.OperatorSetNames.SWISH))),
        schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.CONV),
                                       schema.OperatorsSet(name=schema.OperatorSetNames.RELU))),
        schema.Fusing(operator_groups=(schema.OperatorsSet(name=schema.OperatorSetNames.CONV),
                                       schema.OperatorsSet(name=schema.OperatorSetNames.RELU),
                                       schema.OperatorsSet(name=schema.OperatorSetNames.ADD)))
    ]

    expected_fi = FusingInfo(
        fusing_patterns=fusing_patterns,
        fusing_data={
            "FusedNode_conv1_swish_add": (
                build_node(name="conv1"),
                build_node(name="swish"),
                build_node(name="add", qcs=qcs)
            ),
            "FusedNode_conv2_swish_1_add_1": (
                build_node(name="conv2"),
                build_node(name="swish_1"),
                build_node(name="add_1", qcs=qcs)
            ),
            "FusedNode_conv3_relu": (
                build_node(name="conv3"),
                build_node(name="relu", qcs=qcs)
            ),
            "FusedNode_conv4_relu_1_add_2": (
                build_node(name="conv4"),
                build_node(name="relu_1"),
                build_node(name="add_2", qcs=qcs)
            ),
            "FusedNode_dense1_swish_2": (
                build_node(name="dense1"),
                build_node(name="swish_2", qcs=qcs)
            ),
            "FusedNode_dense2_swish_3": (
                build_node(name="dense2"),
                build_node(name="swish_3", qcs=qcs)
            ),
        }
    )

    def _get_tpc(self, default_quant_cfg_options):
        operator_sets = [
            schema.OperatorsSet(name=schema.OperatorSetNames.CONV),
            schema.OperatorsSet(name=schema.OperatorSetNames.ADD,
                                qc_options=get_activation_mp_options(self.last_node_activation_nbits)),
            schema.OperatorsSet(name=schema.OperatorSetNames.RELU,
                                qc_options=get_activation_mp_options(self.last_node_activation_nbits)),
            schema.OperatorsSet(name=schema.OperatorSetNames.SWISH,
                                qc_options=get_activation_mp_options(self.last_node_activation_nbits)),
            schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED,
                                qc_options=get_activation_mp_options(self.last_node_activation_nbits)),
        ]
        return schema.TargetPlatformCapabilities(
            default_qco=default_quant_cfg_options,
            tpc_platform_type="test",
            operator_set=operator_sets,
            fusing_patterns=self.fusing_patterns
        )

    def _get_model(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding='same')
                self.conv2 = nn.Conv2d(3, 3, kernel_size=(1, 1), padding='same')
                self.conv3 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding='same')
                self.conv4 = nn.Conv2d(3, 3, kernel_size=(1, 1), padding='same')
                self.conv5 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding='same')
                # NOTE(review): conv5/conv6 are never used in forward(); kept so the
                # model's parameter set is unchanged.
                self.conv6 = nn.Conv2d(3, 3, kernel_size=(1, 1), padding='same')
                self.relu = nn.ReLU()
                self.swish = nn.SiLU()
                self.flatten = nn.Flatten()
                self.dense1 = nn.Linear(768, out_features=16)
                self.dense2 = nn.Linear(16, out_features=16)

            def forward(self, inputs):
                x = self.conv1(inputs)
                x = self.swish(x)
                x1 = torch.add(x, inputs)
                x2 = self.conv2(x1)
                x2 = self.swish(x2)
                x2 = torch.add(x1, x2)
                x2 = self.conv3(x2)
                x2 = self.relu(x2)
                x3 = self.conv4(x2)
                x3 = self.relu(x3)
                x3 = torch.add(x3, x2)
                x3 = self.flatten(x3)
                x3 = self.dense1(x3)
                x3 = self.swish(x3)
                x3 = self.dense2(x3)
                y = self.swish(x3)
                return y

        return Model()


class TestFusingConvSwishWithMultiSuccessors(BaseTestFusingInfoGeneratorPytorch):
    """Conv+Swish where the swish output feeds two branches; the fusion must still hold."""

    last_node_activation_nbits, qcs = random_activation_configs()

    fusing_patterns = [
        schema.Fusing(operator_groups=(
            schema.OperatorsSet(name=schema.OperatorSetNames.CONV),
            schema.OperatorsSet(name=schema.OperatorSetNames.SWISH)))
    ]

    expected_fi = FusingInfo(
        fusing_patterns=fusing_patterns,
        fusing_data={
            "FusedNode_conv1_swish": (
                build_node(name="conv1"),
                build_node(name="swish", qcs=qcs)
            )
        }
    )

    def _get_tpc(self, default_quant_cfg_options):
        conv_set = schema.OperatorsSet(name=schema.OperatorSetNames.CONV)
        swish_set = schema.OperatorsSet(name=schema.OperatorSetNames.SWISH,
                                        qc_options=get_activation_mp_options(self.last_node_activation_nbits))
        return schema.TargetPlatformCapabilities(
            default_qco=default_quant_cfg_options,
            tpc_platform_type="test",
            operator_set=[conv_set, swish_set],
            fusing_patterns=self.fusing_patterns
        )

    def _get_model(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
                self.swish = nn.SiLU()
                self.branch1 = nn.Conv2d(16, 8, kernel_size=1)
                self.branch2 = nn.Conv2d(16, 8, kernel_size=1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.swish(x)
                b1 = self.branch1(x)
                b2 = self.branch2(x)
                return b1 + b2

        return Model()


class TestFusingConvReluWithMultiPredecessors(BaseTestFusingInfoGeneratorPytorch):
    """Conv+ReLU after a merge point (two predecessor convs); only conv3+relu fuses."""

    last_node_activation_nbits, qcs = random_activation_configs()

    fusing_patterns = [
        schema.Fusing(operator_groups=(
            schema.OperatorsSet(name=schema.OperatorSetNames.CONV),
            schema.OperatorsSet(name=schema.OperatorSetNames.RELU)))
    ]

    expected_fi = FusingInfo(
        fusing_patterns=fusing_patterns,
        fusing_data={
            "FusedNode_conv3_relu": (
                build_node(name="conv3"),
                build_node(name="relu", qcs=qcs)
            )
        }
    )

    def _get_tpc(self, default_quant_cfg_options):
        conv_set = schema.OperatorsSet(name=schema.OperatorSetNames.CONV)
        relu_set = schema.OperatorsSet(name=schema.OperatorSetNames.RELU,
                                       qc_options=get_activation_mp_options(self.last_node_activation_nbits))
        return schema.TargetPlatformCapabilities(
            default_qco=default_quant_cfg_options,
            tpc_platform_type="test",
            operator_set=[conv_set, relu_set],
            fusing_patterns=self.fusing_patterns
        )

    def _get_model(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
                self.conv2 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
                self.conv3 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
                self.relu = nn.ReLU()

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                merged = x1 + x2
                x3 = self.conv3(merged)
                return self.relu(x3)

        return Model()
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from tests_pytest._fw_tests_common_base.fusing.base_graph_with_fusing_metadata_test import BaseGraphWithFusingMetadataTest +from tests_pytest.pytorch_tests.torch_test_util.torch_test_mixin import TorchFwMixin + +import torch.nn as nn + +class TestGraphWithFusionMetadataPytorch(BaseGraphWithFusingMetadataTest, TorchFwMixin): + + layer_class_relu = nn.ReLU + + def _data_gen(self): + return self.get_basic_data_gen(shapes=[(1, 3, 5, 5)])() + + def _get_model(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, kernel_size=(3, 3)) + self.relu = nn.ReLU() + self.flatten = nn.Flatten() + self.linear = nn.Linear(in_features=27, out_features=10) + self.softmax = nn.Softmax() + + def forward(self, x): + x = self.conv(x) + x = self.relu(x) + x = self.flatten(x) + x = self.linear(x) + x = self.softmax(x) + return x + + return Model() + +