Refactor fusing #1386 (Merged)
Commits (63)
- 1cc4739: Fusing refactor
- a2e4632: remove commented out test file
- 1485cdc: revert changes in keras mixin
- e08676f: remove unneeded fusion tests packages hirarchy
- bb8cf01: revert old tests that were commented out
- 0bb6f5a: fix wrong type hint in test_activation_weights_composition_substition
- d7135d6: remove commented out code from runner
- 8f3b1d9: use internal graph for the final qat model instead of the fused graph
- 4cc1533: add comment for second moment correction about the fusing info correc…
- 3f6fe40: remove old fusing data from graph
- 2ea3f2a: update old pytorch tests
- 595f4bd: add check for graph type in torch model builder
- a9ae9f9: add comments to fusing info
- e1cad46: add comments to graph fuser and graph with metadata
- 43fbf1d: Set version for onnxruntime-extensions
- 5c16f03: adapt keras unit tests
- 7c52939: Revert "Set version for onnxruntime-extensions"
- 7619677: Merge remote-tracking branch 'origin/main' into refactor-fusing
- 8406b2c: fix comments in fusing info
- fb0a30b: remove the deepcopy in get_all_fused_operations
- ec583f1: verify fusing info is consistent when using graph fuser
- 44f2256: move function to disable activation quantization from fusing info to …
- 5544545: use prefic of op id as constant
- 6e7d170: pass only fusing patterns instead of entire fqc
- 1ff0cc9: use dataclass for FusingInfo
- c1a413b: rename filter_fusing_patterns
- ce79fc8: remove duplicate def of FusedLayerType
- 459394e: rename fusing mateadata wrapper
- 671c58a: fix old name of FusingMetadataWrapper in comments
- 34c2820: fix comments in fuse method of GraphFuser
- 0054571: rename fuse method in graph fuser
- bd75774: save fusing info instead of multiple fetches
- b1be93a: set the graph without the fusing metadata in the pytorch back2fw
- a589333: add missing types in BaseGraphWithFusingMetadataTest
- 30e3645: adjust the test of activation-weight composition due to the fusion re…
- e7ac1de: migrate old unittests
- f5c0a07: tests with multiple successors/predecessors
- 7551d34: add test that checks the case of new fusion due to new node
- 98c14f5: add test that checks a valid graph change does not fail the validatio…
- 8ee3191: merge changes with main
- d78ac46: move tests to a designated package
- 4c396ae: fix batchnorm_reconstruction due to replacing list with tuple in fus…
- 4580a98: use internal graph for mp for now
- 3ab79d9: fix missing import of removed test
- 63a7dff: fix wrong path in test_cfg_candidates_filter
- c170f2a: run pytest before unittests
- 773c115: add missing license
- d31c53c: extract minimal_cfg_options from minimal_tpc
- 106e694: fix new argument in test_cfg_candidates_filter
- a16e47a: replace wrapper with embedding the fusing info in the graph
- f485a36: rewrite the funsions in test_fusing_info more explicity
- 8a0334b: return fusion to test in test_activation_weights_composition_substitu…
- 68b560c: disable validation in tests of virtual graph
- d097e77: fix tests
- 2011098: merge from main
- 48ab653: rename tests files to remove wrapper
- e112c65: run pytests in keras before unittests
- b640c45: add integration tests for GraphFuser
- 39dcce1: disable validation in keras function tests for virtual graph
- 57320b0: use tuple for the op cfg options instead of list
- 3031dba: fix tpc in tests where default options had more than one option
- e81ebf5: extend fusing info tests
- 7000c8d: change skip_validation flag to protected
model_compression_toolkit/core/common/fusion/fusing_info.py (+402, -0; large diff not rendered by default)
model_compression_toolkit/core/common/fusion/fusing_metadata_wrapper.py (+159, -0)
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import types
from functools import wraps
from typing import Any, Iterator

from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo


class FusingMetadataWrapper:
    def __init__(self, graph: Graph, fusing_info: FusingInfo):
        """
        Initialize with a graph and its fusing information.

        Args:
            graph: The neural network graph (e.g., a networkx.DiGraph or similar).
            fusing_info: FusingInfo mapping fused-operation IDs to the nodes they contain.
        """
        assert isinstance(graph, Graph)
        self._internal_graph = graph
        self._fusing_info = fusing_info
        self._fusing_info.validate(graph)  # Ensure initial consistency
        # TODO: temporarily disable activation quantization to keep the current behavior.
        #  This will be removed in the future.
        self._disable_nodes_activation_quantization()

    # __getstate__ and __setstate__ were added to fix a recursion error during copy.deepcopy.
    # Without them, deepcopy endlessly traverses attributes via __getattr__, causing a loop.
    # __getstate__ defines what to copy (self._internal_graph and self._fusing_info), and
    # __setstate__ rebuilds the object, ensuring a clean copy without recursion, assuming
    # Graph and FusingInfo are themselves copyable.
    def __getstate__(self):
        """
        Define how the object is serialized for copying.
        Returns a dictionary of the essential attributes.
        """
        self._fusing_info.validate(self._internal_graph)
        return self.__dict__.copy()

    def __setstate__(self, state):
        """
        Reconstruct the object from the serialized state.

        Args:
            state: Dictionary containing the serialized attributes.
        """
        self.__dict__.update(state)
        self._fusing_info.validate(self._internal_graph)

    def __getattr__(self, name: str) -> Any:
        """
        Delegate attribute access to the underlying graph if not found in FusingMetadataWrapper.

        Ensures that if the accessed attribute is a callable (e.g., a method like remove_node),
        it is wrapped so that the fusing information is validated after execution.
        Non-callable attributes are returned directly without validation.

        Args:
            name: The name of the attribute being accessed.

        Returns:
            The attribute or a wrapped method from self._internal_graph.

        Raises:
            AttributeError: If the attribute doesn't exist in self._internal_graph.
        """
        # TODO: Optimize validation by restricting it to known modifying methods. For now,
        #  validating after every method call ensures correctness. In the future, define
        #  explicit modification methods (e.g., remove_node) in FusingMetadataWrapper for
        #  better efficiency.
        graph_attr = getattr(self._internal_graph, name)
        # Only wrap methods or functions, excluding properties and descriptors.
        if isinstance(graph_attr, (types.MethodType, types.FunctionType)):
            @wraps(graph_attr)
            def wrapper(*args, **kwargs):
                result = graph_attr(*args, **kwargs)
                self._fusing_info.validate(self._internal_graph)
                return result
            return wrapper

        return graph_attr

    def __iter__(self) -> Iterator[BaseNode]:
        """
        Make FusingMetadataWrapper iterable by delegating to the wrapped graph's iterator.

        This allows FusingMetadataWrapper to be used in contexts expecting an iterable of nodes,
        such as topological_sort, without requiring changes to external code.

        Returns:
            An iterator over the nodes in the underlying graph.
        """
        return iter(self._internal_graph)

    def __getitem__(self, key: Any) -> Any:
        """
        Delegate subscripting to the wrapped graph.

        This enables FusingMetadataWrapper to support dictionary-like access (e.g., graph[node][child])
        as required by operations like topological_generations in NetworkX, maintaining
        compatibility with code expecting a subscriptable Graph object.

        Args:
            key: The key (e.g., node) to look up in the graph.

        Returns:
            The value associated with the key in the wrapped graph.

        Raises:
            KeyError: If the key doesn't exist in self._internal_graph.
        """
        return self._internal_graph[key]

    def update_fusing_info(self, new_fusing_info: FusingInfo):
        """Replace the fusing information with a new FusingInfo instance."""
        self._fusing_info = new_fusing_info

    def get_internal_graph(self) -> Graph:
        """Return the original graph."""
        return self._internal_graph

    def get_fusing_info(self) -> FusingInfo:
        """Return the fusing information."""
        return self._fusing_info

    def is_part_of_fused_op(self, node: BaseNode) -> bool:
        """Check whether a node is part of any fused operation."""
        return self._fusing_info.is_node_in_fused_op(node)

    def _disable_nodes_activation_quantization(self):
        """
        Disable activation quantization for all but the last node inside each fused operator.
        """
        nodes_to_disable = [node
                            for nodes in self._fusing_info.get_all_fused_operations().values()
                            for node in nodes[:-1]]
        for node in nodes_to_disable:
            for qc in node.candidates_quantization_cfg:
                qc.activation_quantization_cfg.enable_activation_quantization = False

    def validate(self):
        """
        Check that the internal graph and fusing data are consistent.
        """
        return self._fusing_info.validate(self._internal_graph)
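The delegation pattern used by FusingMetadataWrapper can be illustrated with a self-contained toy (ToyGraph and ValidatingWrapper are hypothetical stand-ins, not MCT classes): attribute access is forwarded to an inner object, every delegated method call triggers a re-validation, and __getstate__/__setstate__ keep copy.deepcopy from recursing through __getattr__.

```python
import copy
import types
from functools import wraps

class ToyGraph:
    """Hypothetical stand-in for the wrapped Graph."""
    def __init__(self):
        self.nodes = ["a", "b"]

    def remove_node(self, node):
        self.nodes.remove(node)

class ValidatingWrapper:
    """Delegates to an inner object, re-validating after every delegated method call."""
    def __init__(self, inner):
        self._inner = inner
        self.validations = 0

    def validate(self):
        # Stand-in for FusingInfo.validate(graph).
        self.validations += 1

    def __getattr__(self, name):
        # Only reached when `name` is not found on the wrapper itself.
        attr = getattr(self._inner, name)
        if isinstance(attr, (types.MethodType, types.FunctionType)):
            @wraps(attr)
            def wrapped(*args, **kwargs):
                result = attr(*args, **kwargs)
                self.validate()
                return result
            return wrapped
        return attr

    # Without these two methods, deepcopy reconstructs an empty instance and
    # probes it for __setstate__ via __getattr__, which touches self._inner
    # before it exists and recurses into __getattr__ forever.
    def __getstate__(self):
        return self.__dict__.copy()

    def __setstate__(self, state):
        self.__dict__.update(state)

w = ValidatingWrapper(ToyGraph())
w.remove_node("a")        # delegated call; triggers one validation
clone = copy.deepcopy(w)  # safe: __getstate__/__setstate__ avoid __getattr__ recursion
```

The wrapping-on-access choice mirrors the TODO in the file above: it validates after every delegated call, which is correct but pessimistic; restricting validation to known mutating methods would be cheaper.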
Review comment:
Is the TODO here planned for this PR? Is it actually necessary? You can always retrieve the original weights from the original graph.

Author reply:
No, in the PR that handles the MP.