Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
1cc4739
Fusing refactor
Mar 10, 2025
a2e4632
remove commented out test file
Mar 12, 2025
1485cdc
revert changes in keras mixin
Mar 12, 2025
e08676f
remove unneeded fusion tests packages hierarchy
Mar 12, 2025
bb8cf01
revert old tests that were commented out
Mar 12, 2025
0bb6f5a
fix wrong type hint in test_activation_weights_composition_substitution
Mar 12, 2025
d7135d6
remove commented out code from runner
Mar 12, 2025
8f3b1d9
use internal graph for the final qat model instead of the fused graph
Mar 12, 2025
4cc1533
add comment for second moment correction about the fusing info correc…
Mar 12, 2025
3f6fe40
remove old fusing data from graph
Mar 12, 2025
2ea3f2a
update old pytorch tests
Mar 12, 2025
595f4bd
add check for graph type in torch model builder
Mar 13, 2025
a9ae9f9
add comments to fusing info
Mar 14, 2025
e1cad46
add comments to graph fuser and graph with metadata
Mar 14, 2025
43fbf1d
Set version for onnxruntime-extensions
Mar 14, 2025
5c16f03
adapt keras unit tests
Mar 14, 2025
7c52939
Revert "Set version for onnxruntime-extensions"
Mar 17, 2025
7619677
Merge remote-tracking branch 'origin/main' into refactor-fusing
Mar 17, 2025
8406b2c
fix comments in fusing info
Mar 17, 2025
fb0a30b
remove the deepcopy in get_all_fused_operations
Mar 17, 2025
ec583f1
verify fusing info is consistent when using graph fuser
Mar 18, 2025
44f2256
move function to disable activation quantization from fusing info to …
Mar 18, 2025
5544545
use prefix of op id as constant
Mar 18, 2025
6e7d170
pass only fusing patterns instead of entire fqc
Mar 18, 2025
1ff0cc9
use dataclass for FusingInfo
Mar 18, 2025
c1a413b
rename filter_fusing_patterns
Mar 24, 2025
ce79fc8
remove duplicate def of FusedLayerType
Mar 24, 2025
459394e
rename fusing metadata wrapper
Mar 24, 2025
671c58a
fix old name of FusingMetadataWrapper in comments
Mar 24, 2025
34c2820
fix comments in fuse method of GraphFuser
Mar 24, 2025
0054571
rename fuse method in graph fuser
Mar 24, 2025
bd75774
save fusing info instead of multiple fetches
Mar 24, 2025
b1be93a
set the graph without the fusing metadata in the pytorch back2fw
Mar 24, 2025
a589333
add missing types in BaseGraphWithFusingMetadataTest
Mar 24, 2025
30e3645
adjust the test of activation-weight composition due to the fusion re…
Mar 24, 2025
e7ac1de
migrate old unittests
Mar 25, 2025
f5c0a07
tests with multiple successors/predecessors
Mar 26, 2025
7551d34
add test that checks the case of new fusion due to new node
Mar 26, 2025
98c14f5
add test that checks a valid graph change does not fail the validatio…
Mar 26, 2025
8ee3191
merge changes with main
Mar 26, 2025
d78ac46
move tests to a designated package
Mar 26, 2025
4c396ae
fix batchnorm_reconstruction due to replacing list with tuple in fus…
Mar 26, 2025
4580a98
use internal graph for mp for now
Mar 26, 2025
3ab79d9
fix missing import of removed test
Mar 26, 2025
63a7dff
fix wrong path in test_cfg_candidates_filter
Mar 26, 2025
c170f2a
run pytest before unittests
Mar 26, 2025
773c115
add missing license
Mar 26, 2025
d31c53c
extract minimal_cfg_options from minimal_tpc
Mar 26, 2025
106e694
fix new argument in test_cfg_candidates_filter
Mar 27, 2025
a16e47a
replace wrapper with embedding the fusing info in the graph
Apr 1, 2025
f485a36
rewrite the fusions in test_fusing_info more explicitly
Apr 1, 2025
8a0334b
return fusion to test in test_activation_weights_composition_substitu…
Apr 1, 2025
68b560c
disable validation in tests of virtual graph
Apr 1, 2025
d097e77
fix tests
Apr 1, 2025
2011098
merge from main
Apr 1, 2025
48ab653
rename tests files to remove wrapper
Apr 1, 2025
e112c65
run pytests in keras before unittests
Apr 1, 2025
b640c45
add integration tests for GraphFuser
Apr 2, 2025
39dcce1
disable validation in keras function tests for virtual graph
Apr 2, 2025
57320b0
use tuple for the op cfg options instead of list
Apr 2, 2025
3031dba
fix tpc in tests where default options had more than one option
Apr 2, 2025
e81ebf5
extend fusing info tests
Apr 2, 2025
7000c8d
change skip_validation flag to protected
Apr 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
434 changes: 434 additions & 0 deletions model_compression_toolkit/core/common/fusion/fusing_info.py

Large diffs are not rendered by default.

55 changes: 34 additions & 21 deletions model_compression_toolkit/core/common/fusion/graph_fuser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,10 +13,13 @@
# limitations under the License.
# ==============================================================================

from typing import Dict, List
import copy
from typing import List

from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.base_graph import OutTensor
from model_compression_toolkit.core.common.fusion.graph_with_fusing_metadata import GraphWithFusingMetadata
from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
from itertools import product


class FusedLayerType:
Expand All @@ -27,35 +30,31 @@ class FusedLayerType:
def __init__(self):
self.__name__ = 'FusedLayer'


class GraphFuser:

def create_fused_graph(self, graph: Graph) -> Dict[str, str]:
def fuse(self, fused_graph: GraphWithFusingMetadata):
"""
GraphFuser is responsible for fusing nodes in a networkx graph.
The fusion process involves:
1. Creating new fused nodes to represent these groups.
2. Updating the graph structure to replace the original nodes with fused nodes.
3. Maintaining mapping of original node names to their fused node names.

Args:
graph: Graph to fuse its nodes.

Returns:
Mapping of original node names to their fused node names
"""
fused_nodes_mapping = {}
# Iterate through each group of nodes to be fused
for fused_nodes_list in graph.fused_nodes:
new_fused_node = self._create_fused_node(fused_nodes_list)
self._replace_nodes_with_fused_node(graph, fused_nodes_list, new_fused_node)
# Update the mapping to keep track of which original nodes are now part of which fused nodes
for node in fused_nodes_list:
fused_nodes_mapping[node.name] = new_fused_node.name
return fused_nodes_mapping
graph = copy.deepcopy(fused_graph) # this will be the new fused graph
for fused_node_id, fused_nodes_list in graph.get_fusing_info().get_all_fused_operations().items():
new_fused_node = self._create_fused_node(fused_node_id, fused_nodes_list)
new_fused_nodes_list = [graph.get_internal_graph().find_node_by_name(n.name)[0] for n in fused_nodes_list]
self._replace_nodes_with_fused_node(graph.get_internal_graph(), new_fused_nodes_list, new_fused_node)
return graph.get_internal_graph()


@staticmethod
def _create_fused_node(nodes: List[BaseNode]) -> BaseNode:
def _create_fused_node(fused_node_id: str, nodes: List[BaseNode]) -> BaseNode:
"""
Create a new node that represents the fusion of the given nodes.

Expand All @@ -67,15 +66,29 @@ def _create_fused_node(nodes: List[BaseNode]) -> BaseNode:
"""
# Create a new node with a name that reflects its components
# Use the input shape of the first node and output shape of the last node
fused_node = BaseNode(name='FusedNode_' + '_'.join([node.name for node in nodes]),
fused_node_name = fused_node_id

# TODO: consider replacing the fused node with a sub-model to allow inference on it, etc.
fused_node = BaseNode(name=fused_node_name,
framework_attr={},
input_shape=nodes[0].input_shape,
output_shape=nodes[-1].output_shape,
weights={},
weights={}, # TODO: update with weights of all nodes
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the todo here planned for this PR?
is it necessary actually? because you can always retrieve the original weights from the original graph

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, in the PR that handles the MP.

layer_class=FusedLayerType)

# Preserve the final activation quantization configuration
# This is important for maintaining the correct behavior of the fused node
# Create candidates for this node (we assume that the weights configuration should be taken from the first node, and the activation configuration
# is the output quantization configuration of the last node. We ignore all configurations of middle nodes.
weight_cfgs = [c.weights_quantization_cfg for c in nodes[0].candidates_quantization_cfg]
activation_cfgs = [c.activation_quantization_cfg for c in nodes[-1].candidates_quantization_cfg]
if weight_cfgs and activation_cfgs:
combinations = list(product(weight_cfgs, activation_cfgs))
fused_node.candidates_quantization_cfg = [
CandidateNodeQuantizationConfig(weights_quantization_cfg=w, activation_quantization_cfg=a)
for w, a in combinations
]

# Keep the final configurations if they were set already.
fused_node.final_weights_quantization_cfg = nodes[0].final_weights_quantization_cfg
fused_node.final_activation_quantization_cfg = nodes[-1].final_activation_quantization_cfg

return fused_node
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import types

from functools import wraps

from typing import Any, Iterator

from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo


class FusedLayerType:
    """
    Stand-in layer type for fused nodes.

    Graph display code reads a ``__name__`` attribute from a node's layer
    type, so each instance exposes one with the fixed value 'FusedLayer'.
    """

    def __init__(self):
        # Mimic a real layer class, which would expose __name__ for display.
        self.__name__ = 'FusedLayer'


class GraphWithFusingMetadata:
    """
    Wrap a Graph together with its FusingInfo, delegating most attribute access
    to the underlying graph while re-validating the fusing metadata after every
    delegated method call so the two stay consistent.
    """

    def __init__(self, graph: Graph, fusing_info: FusingInfo):
        """
        Initialize with a graph and its fusing information.

        Args:
            graph: The neural network Graph to wrap.
            fusing_info: FusingInfo object describing the fused operations of the graph.
        """
        assert isinstance(graph, Graph)
        self._internal_graph = graph
        self._fusing_info = fusing_info
        self._fusing_info.validate(graph)  # Ensure initial consistency
        # TODO: temp disable activation quantization to keep similar functionality. This will be removed in the future
        self._disable_nodes_activation_quantization()

    # We added __getstate__ and __setstate__ to GraphWithFusingMetadata to fix a recursion error during
    # copy.deepcopy. Without these, deepcopy endlessly traverses attributes via __getattr__, causing a loop.
    # Now, __getstate__ defines what to copy (self._internal_graph and self._fusing_info), and __setstate__
    # rebuilds the object, ensuring a clean copy without recursion, assuming Graph and FusingInfo are copyable.
    def __getstate__(self):
        """
        Define how the object is serialized for copying.

        Validates the fusing info against the graph before handing out state,
        so an inconsistent object is never copied/pickled silently.

        Returns:
            A shallow copy of the instance's attribute dictionary.
        """
        self._fusing_info.validate(self._internal_graph)
        return self.__dict__.copy()

    def __setstate__(self, state):
        """
        Reconstruct the object from the serialized state.

        Args:
            state: Dictionary containing the serialized attributes.
        """
        # Restore attributes directly via __dict__ so __getattr__ is not
        # triggered on a half-initialized object, then re-check consistency.
        self.__dict__.update(state)
        self._fusing_info.validate(self._internal_graph)

    def __getattr__(self, name: str) -> Any:
        """
        Delegate attribute access to the underlying graph if not found on this wrapper.

        Ensures that if the accessed attribute is a callable (e.g., a method like remove_node),
        it is wrapped so that the fusing information is validated after execution.
        Non-callable attributes are returned directly without validation.

        Args:
            name: The name of the attribute being accessed.

        Returns:
            The attribute or a wrapped method from self._internal_graph.

        Raises:
            AttributeError: If the attribute doesn't exist in self._internal_graph.
        """

        # TODO: Optimize validation by restricting it to known modifying methods to improve efficiency. For now,
        # validating after every method call ensures correctness. In the
        # future, define explicit modification methods (e.g., remove_node)
        # here for better efficiency.

        graph_attr = getattr(self._internal_graph, name)
        # Only wrap methods or functions, excluding properties and descriptors
        if isinstance(graph_attr, (types.MethodType, types.FunctionType)):
            @wraps(graph_attr)
            def wrapper(*args, **kwargs):
                # Run the delegated call first, then re-validate so any graph
                # mutation that breaks the fusing metadata fails loudly here.
                result = graph_attr(*args, **kwargs)
                self._fusing_info.validate(self._internal_graph)
                return result
            return wrapper

        return graph_attr

    def __iter__(self) -> Iterator[BaseNode]:
        """
        Make the wrapper iterable by delegating to the underlying graph's iterator.

        This allows it to be used in contexts expecting an iterable of nodes,
        such as topological_sort, without requiring changes to external code.

        Returns:
            An iterator over the nodes in the underlying graph.
        """
        return iter(self._internal_graph)

    def __getitem__(self, key: Any) -> Any:
        """
        Delegate subscripting to the underlying graph.

        This enables the wrapper to support dictionary-like access (e.g., graph[node][child])
        as required by operations like topological_generations in NetworkX, maintaining
        compatibility with code expecting a subscriptable Graph object.

        Args:
            key: The key (e.g., node) to look up in the graph.

        Returns:
            The value associated with the key in the underlying graph.

        Raises:
            KeyError: If the key doesn't exist in self._internal_graph.
        """
        return self._internal_graph[key]

    def update_fusing_info(self, new_fusing_info: FusingInfo) -> None:
        # Replace the fusing metadata wholesale. NOTE(review): unlike delegated
        # graph methods, this does not re-validate against the internal graph —
        # the caller is trusted to supply consistent info; confirm intended.
        self._fusing_info = new_fusing_info

    def get_internal_graph(self) -> Graph:
        """Return the wrapped (original) graph."""
        return self._internal_graph

    def get_fusing_info(self) -> FusingInfo:
        """Return the fusing information."""
        return self._fusing_info

    def is_part_of_fused_op(self, node):
        """Check if a node is part of any fused operation (delegates to FusingInfo)."""
        return self._fusing_info.is_node_in_fused_op(node)

    def _disable_nodes_activation_quantization(self) -> None:
        """
        Disable activation quantization on every candidate config of the nodes
        that, according to the fusing info, do not need their activation
        quantized (i.e., intermediate nodes inside fused operations).
        """
        # TODO: temp disable activation quantization to keep similar functionality. This will be removed in the future
        nodes_to_disable = self._fusing_info.get_nodes_to_disable_act_quantization()
        for node in nodes_to_disable:
            for qc in node.candidates_quantization_cfg:
                qc.activation_quantization_cfg.enable_activation_quantization = False

    def validate(self):
        """
        Check that the internal graph and fusing metadata are consistent
        (delegates to FusingInfo.validate).
        """
        return self._fusing_info.validate(self._internal_graph)
Loading