-
Notifications
You must be signed in to change notification settings - Fork 79
Refactor fusing #1386
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor fusing #1386
Changes from 18 commits
1cc4739
a2e4632
1485cdc
e08676f
bb8cf01
0bb6f5a
d7135d6
8f3b1d9
4cc1533
3f6fe40
2ea3f2a
595f4bd
a9ae9f9
e1cad46
43fbf1d
5c16f03
7c52939
7619677
8406b2c
fb0a30b
ec583f1
44f2256
5544545
6e7d170
1ff0cc9
c1a413b
ce79fc8
459394e
671c58a
34c2820
0054571
bd75774
b1be93a
a589333
30e3645
e7ac1de
f5c0a07
7551d34
98c14f5
8ee3191
d78ac46
4c396ae
4580a98
3ab79d9
63a7dff
c170f2a
773c115
d31c53c
106e694
a16e47a
f485a36
8a0334b
68b560c
d097e77
2011098
48ab653
e112c65
b640c45
39dcce1
57320b0
3031dba
e81ebf5
7000c8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. | ||
| # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
|
|
@@ -13,10 +13,13 @@ | |
| # limitations under the License. | ||
| # ============================================================================== | ||
|
|
||
| from typing import Dict, List | ||
| import copy | ||
| from typing import List | ||
|
|
||
| from model_compression_toolkit.core.common import Graph, BaseNode | ||
| from model_compression_toolkit.core.common.graph.base_graph import OutTensor | ||
| from model_compression_toolkit.core.common.fusion.graph_with_fusing_metadata import GraphWithFusingMetadata | ||
| from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor | ||
| from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig | ||
| from itertools import product | ||
|
|
||
|
|
||
| class FusedLayerType: | ||
|
|
@@ -27,35 +30,31 @@ class FusedLayerType: | |
| def __init__(self): | ||
| self.__name__ = 'FusedLayer' | ||
|
|
||
|
|
||
| class GraphFuser: | ||
|
|
||
| def create_fused_graph(self, graph: Graph) -> Dict[str, str]: | ||
| def fuse(self, fused_graph: GraphWithFusingMetadata): | ||
| """ | ||
| GraphFuser is responsible for fusing nodes in a networkx graph. | ||
| The fusion process involves: | ||
| 1. Creating new fused nodes to represent these groups. | ||
| 2. Updating the graph structure to replace the original nodes with fused nodes. | ||
| 3. Maintaining mapping of original node names to their fused node names. | ||
|
|
||
| Args: | ||
| fused_graph: Graph with fusing metadata whose node groups should be fused. | ||
|
|
||
| Returns: | ||
| The internal graph after each group of fused nodes is replaced by a single fused node. | ||
| """ | ||
| fused_nodes_mapping = {} | ||
| # Iterate through each group of nodes to be fused | ||
| for fused_nodes_list in graph.fused_nodes: | ||
| new_fused_node = self._create_fused_node(fused_nodes_list) | ||
| self._replace_nodes_with_fused_node(graph, fused_nodes_list, new_fused_node) | ||
| # Update the mapping to keep track of which original nodes are now part of which fused nodes | ||
| for node in fused_nodes_list: | ||
| fused_nodes_mapping[node.name] = new_fused_node.name | ||
| return fused_nodes_mapping | ||
| graph = copy.deepcopy(fused_graph) # this will be the new fused graph | ||
reuvenperetz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for fused_node_id, fused_nodes_list in graph.get_fusing_info().get_all_fused_operations().items(): | ||
| new_fused_node = self._create_fused_node(fused_node_id, fused_nodes_list) | ||
| new_fused_nodes_list = [graph.get_internal_graph().find_node_by_name(n.name)[0] for n in fused_nodes_list] | ||
| self._replace_nodes_with_fused_node(graph.get_internal_graph(), new_fused_nodes_list, new_fused_node) | ||
| return graph.get_internal_graph() | ||
|
|
||
|
|
||
| @staticmethod | ||
| def _create_fused_node(nodes: List[BaseNode]) -> BaseNode: | ||
| def _create_fused_node(fused_node_id: str, nodes: List[BaseNode]) -> BaseNode: | ||
| """ | ||
| Create a new node that represents the fusion of the given nodes. | ||
|
|
||
|
|
@@ -67,15 +66,29 @@ def _create_fused_node(nodes: List[BaseNode]) -> BaseNode: | |
| """ | ||
| # Create a new node with a name that reflects its components | ||
| # Use the input shape of the first node and output shape of the last node | ||
| fused_node = BaseNode(name='FusedNode_' + '_'.join([node.name for node in nodes]), | ||
| fused_node_name = fused_node_id | ||
reuvenperetz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| # TODO: consider replacing the fused node with a sub-model to allow inference on it, etc. | ||
| fused_node = BaseNode(name=fused_node_name, | ||
| framework_attr={}, | ||
| input_shape=nodes[0].input_shape, | ||
| output_shape=nodes[-1].output_shape, | ||
| weights={}, | ||
| weights={}, # TODO: update with weights of all nodes | ||
|
||
| layer_class=FusedLayerType) | ||
|
|
||
| # Preserve the final activation quantization configuration | ||
| # This is important for maintaining the correct behavior of the fused node | ||
| # Create candidates for this node (we assume that the weights configuration should be taken from the first node, and the activation configuration | ||
reuvenperetz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # is the output quantization configuration of the last node. We ignore all configurations of middle nodes. | ||
| weight_cfgs = [c.weights_quantization_cfg for c in nodes[0].candidates_quantization_cfg] | ||
| activation_cfgs = [c.activation_quantization_cfg for c in nodes[-1].candidates_quantization_cfg] | ||
| if weight_cfgs and activation_cfgs: | ||
| combinations = list(product(weight_cfgs, activation_cfgs)) | ||
| fused_node.candidates_quantization_cfg = [ | ||
| CandidateNodeQuantizationConfig(weights_quantization_cfg=w, activation_quantization_cfg=a) | ||
| for w, a in combinations | ||
| ] | ||
|
|
||
| # Keep the final configurations if they were set already. | ||
| fused_node.final_weights_quantization_cfg = nodes[0].final_weights_quantization_cfg | ||
| fused_node.final_activation_quantization_cfg = nodes[-1].final_activation_quantization_cfg | ||
|
|
||
| return fused_node | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,170 @@ | ||
| # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
|
|
||
| import types | ||
|
|
||
| from functools import wraps | ||
|
|
||
| from typing import Any, Iterator | ||
|
|
||
| from model_compression_toolkit.core.common import BaseNode, Graph | ||
| from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo | ||
|
|
||
|
|
||
| class FusedLayerType: | ||
| """ | ||
| Used to represent the type of fused layers, since __name__ | ||
| is accessed when the graph is displayed. | ||
| """ | ||
| def __init__(self): | ||
| self.__name__ = 'FusedLayer' | ||
reuvenperetz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| class GraphWithFusingMetadata: | ||
reuvenperetz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| def __init__(self, graph: Graph, fusing_info: FusingInfo): | ||
| """ | ||
| Initialize with a graph and its fusing information. | ||
|
|
||
| Args: | ||
| graph: The neural network graph (e.g., a networkx.DiGraph or similar). | ||
| fusing_info: Dict mapping fused operation IDs to sets of node objects. | ||
| """ | ||
| assert isinstance(graph, Graph) | ||
| self._internal_graph = graph | ||
| self._fusing_info = fusing_info | ||
| self._fusing_info.validate(graph) # Ensure initial consistency | ||
| # TODO: temp disable activation quantization to keep similar functionality. This will be removed in the future | ||
| self._disable_nodes_activation_quantization() | ||
|
|
||
| # We added __getstate__ and __setstate__ to FusedGraph to fix a recursion error during copy.deepcopy. Without | ||
reuvenperetz marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # these, deepcopy endlessly traverses attributes via __getattr__, causing a loop. Now, __getstate__ defines what | ||
| # to copy (self._graph and self._fusing_info), and __setstate__ rebuilds the object, ensuring a clean copy | ||
| # without recursion, assuming Graph and FusingInfo are copyable. | ||
| def __getstate__(self): | ||
| """ | ||
| Define how the object is serialized for copying. | ||
| Returns a dictionary of the essential attributes. | ||
| """ | ||
| self._fusing_info.validate(self._internal_graph) | ||
| return self.__dict__.copy() | ||
|
|
||
| def __setstate__(self, state): | ||
| """ | ||
| Reconstruct the object from the serialized state. | ||
|
|
||
| Args: | ||
| state: Dictionary containing the serialized attributes. | ||
| """ | ||
| self.__dict__.update(state) | ||
| self._fusing_info.validate(self._internal_graph) | ||
|
|
||
| def __getattr__(self, name: str) -> Any: | ||
| """ | ||
| Delegate attribute access to the underlying graph if not found in FusedGraph. | ||
|
|
||
| Ensures that if the accessed attribute is a callable (e.g., a method like remove_node), | ||
| it is wrapped so that the fusing information is validated after execution. | ||
| Non-callable attributes are returned directly without validation. | ||
|
|
||
| Args: | ||
| name: The name of the attribute being accessed. | ||
|
|
||
| Returns: | ||
| The attribute or a wrapped method from self._graph. | ||
|
|
||
| Raises: | ||
| AttributeError: If the attribute doesn't exist in self._graph. | ||
| """ | ||
|
|
||
| # TODO: Optimize validation by restricting it to known modifying methods to improve efficiency. For now, | ||
| # validating after every method call ensures correctness. In the | ||
| # future, define explicit modification methods (e.g., remove_node) | ||
| # in FusedGraph for better efficiency. | ||
|
|
||
| graph_attr = getattr(self._internal_graph, name) | ||
| # Only wrap methods or functions, excluding properties and descriptors | ||
| if isinstance(graph_attr, (types.MethodType, types.FunctionType)): | ||
| @wraps(graph_attr) | ||
| def wrapper(*args, **kwargs): | ||
| result = graph_attr(*args, **kwargs) | ||
| self._fusing_info.validate(self._internal_graph) | ||
| return result | ||
| return wrapper | ||
|
|
||
| return graph_attr | ||
|
|
||
| def __iter__(self) -> Iterator[BaseNode]: | ||
| """ | ||
| Make FusedGraph iterable by delegating to the underlying graph's iterator. | ||
|
|
||
| This allows FusedGraph to be used in contexts expecting an iterable of nodes, | ||
| such as topological_sort, without requiring changes to external code. | ||
|
|
||
| Returns: | ||
| An iterator over the nodes in the underlying graph. | ||
| """ | ||
| return iter(self._internal_graph) | ||
|
|
||
| def __getitem__(self, key: Any) -> Any: | ||
| """ | ||
| Delegate subscripting to the underlying graph. | ||
|
|
||
| This enables FusedGraph to support dictionary-like access (e.g., graph[node][child]) | ||
| as required by operations like topological_generations in NetworkX, maintaining | ||
| compatibility with code expecting a subscriptable Graph object. | ||
|
|
||
| Args: | ||
| key: The key (e.g., node) to look up in the graph. | ||
|
|
||
| Returns: | ||
| The value associated with the key in the underlying graph. | ||
|
|
||
| Raises: | ||
| KeyError: If the key doesn't exist in self._graph. | ||
| """ | ||
| return self._internal_graph[key] | ||
|
|
||
| def update_fusing_info(self, new_fusing_info: FusingInfo): | ||
| self._fusing_info = new_fusing_info | ||
|
|
||
| def get_internal_graph(self): | ||
| """Return the original graph.""" | ||
| return self._internal_graph | ||
|
|
||
| def get_fusing_info(self): | ||
| """Return the fusing information.""" | ||
| return self._fusing_info | ||
|
|
||
| def is_part_of_fused_op(self, node): | ||
| """Check if a node is part of any fused operation.""" | ||
| return self._fusing_info.is_node_in_fused_op(node) | ||
|
|
||
| def _disable_nodes_activation_quantization(self): | ||
| """ | ||
| Disable activation quantization for nodes that do not need it due to fusion. | ||
| Note: | ||
| the nodes to update are taken from the fusing info rather than passed as arguments | ||
| """ | ||
| # TODO: temp disable activation quantization to keep similar functionality. This will be removed in the future | ||
| nodes_to_disable = self._fusing_info.get_nodes_to_disable_act_quantization() | ||
| for node in nodes_to_disable: | ||
| for qc in node.candidates_quantization_cfg: | ||
| qc.activation_quantization_cfg.enable_activation_quantization = False | ||
|
|
||
| def validate(self): | ||
| """ | ||
| Check if the internal graph and fusing data are consistent. | ||
| """ | ||
| return self._fusing_info.validate(self._internal_graph) | ||
Uh oh!
There was an error while loading. Please reload this page.