Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a0d7d40
Add fuse_op_quantization_config field to schema v2 of Fusing class
kkawa14 Apr 15, 2025
9225537
Renamed unittest filename
kkawa14 Apr 15, 2025
ac4db39
Rename test class name
kkawa14 Apr 15, 2025
6be96b7
Merge branch 'main' into add_fuse_op_quant_config_to_schema_v2
kkawa14 Apr 16, 2025
47f73ec
Merge branch 'main' into add_fuse_op_quant_config_to_schema_v2
kkawa14 Apr 17, 2025
407d840
Merge branch 'main' into add_fuse_op_quant_config_to_schema_v2
kkawa14 Apr 21, 2025
79ef579
Remove unnecessary validation
kkawa14 Apr 21, 2025
e59aff8
Merge branch 'main' into apply_quant_info_to_fusinginfo
kkawa14 Apr 21, 2025
9ecac66
Merge remote-tracking branch 'origin/add_fuse_op_quant_config_to_sche…
kkawa14 Apr 21, 2025
41e55be
Merge branch 'main' into apply_quant_info_to_fusinginfo
kkawa14 Apr 22, 2025
271b78c
Experimentally applying quant info to fusion info
kkawa14 Apr 22, 2025
9fd8fc3
Experimental code at fusing_info.py
kkawa14 Apr 23, 2025
d0bdc68
modify experimentally code
kkawa14 Apr 23, 2025
9fdb0e8
modify initialize method
kkawa14 Apr 23, 2025
f2eab66
Merge branch 'main' into apply_quant_info_to_fusinginfo
kkawa14 Apr 24, 2025
e41ce92
apply quantization info to fusinginfo
kkawa14 Apr 24, 2025
6238c4b
change variable name and conditional branching
kkawa14 Apr 24, 2025
fbf08de
adding unittest
kkawa14 Apr 25, 2025
637e2ea
Merge branch 'main' into apply_quant_info_to_fusinginfo
kkawa14 Apr 25, 2025
782620d
modify test expectation (missing type)
kkawa14 Apr 25, 2025
06eb698
Merge branch 'main' into apply_quant_info_to_fusinginfo
kkawa14 May 7, 2025
87bb5fa
Fixed review feedback
kkawa14 May 7, 2025
8313c9f
Merge branch 'main' into apply_quant_info_to_fusinginfo
kkawa14 May 7, 2025
1d77024
Add other quant_op_cfg
kkawa14 May 7, 2025
861b390
modify
kkawa14 May 7, 2025
90d7612
Adding error conditions
kkawa14 May 7, 2025
764afa5
change the method of removing fused_op_id_to_quant_config
kkawa14 May 7, 2025
36d4bd2
fix line break
kkawa14 May 8, 2025
629f837
Merge branch 'main' into apply_quant_info_to_fusinginfo
kkawa14 May 9, 2025
35ffd84
Merge branch 'main' into apply_quant_info_to_fusinginfo
kkawa14 May 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions model_compression_toolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,8 @@
NODE_NAME = 'node_name'
TOTAL_SIZE = 'total_size'
NODE_OUTPUT_INDEX = 'node_output_index'


# Fusing Patterns constants
FUSED_LAYER_PATTERN = 'fused_layer_pattern'
FUSED_OP_QUANT_CONFIG = 'fused_op_quantization_config'
82 changes: 75 additions & 7 deletions model_compression_toolkit/core/common/fusion/fusing_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================

from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig
from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
from dataclasses import dataclass, field

from typing import Optional, List, Dict, Any, Tuple
Expand Down Expand Up @@ -41,6 +43,7 @@ class FusingInfo:
fusing_patterns: any = None
fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict)
node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict)
fused_op_id_to_quant_config: Dict[str, OpQuantizationConfig] = field(default_factory=dict)

def __post_init__(self):
"""Validates and initializes mappings after dataclass instantiation."""
Expand All @@ -49,6 +52,7 @@ def __post_init__(self):
assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}"

self._init_node_mapping()
self._init_quantization_config_map()

def _init_node_mapping(self) -> None:
"""
Expand All @@ -59,6 +63,15 @@ def _init_node_mapping(self) -> None:
for node in nodes:
self.node_to_fused_node_map[node.name] = op_id

def _init_quantization_config_map(self) -> None:
"""
Init the mapping between fused operation IDs and their quantization configurations.
"""
self.fused_op_id_to_quant_config.clear()
if self.fusing_patterns is not None:
for op_id, nodes in self.fusing_data.items():
self.set_fused_op_quantization_config(op_id, nodes)

def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
"""
Add a new fused operation with the given ID and set of nodes.
Expand All @@ -78,6 +91,22 @@ def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
for node in nodes:
self.node_to_fused_node_map[node.name] = op_id

# Update the quantization config mapping for this operation
if self.fusing_patterns is not None:
self.set_fused_op_quantization_config(op_id, nodes)

def set_fused_op_quantization_config(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
"""
Set the quantization configuration for a given fused operation ID.

Args:
op_id (str): The identifier for the fused operation.
nodes (Tuple[BaseNode]): The tuple of nodes that form the fused operation.
"""
fusing_pattern = next((fp for fp in self.fusing_patterns if is_valid_fusion([fp.get(FUSED_LAYER_PATTERN)], nodes)), None)
if fusing_pattern is not None:
self.fused_op_id_to_quant_config[op_id] = fusing_pattern.get(FUSED_OP_QUANT_CONFIG)

def remove_fused_operation(self, op_id: str) -> None:
"""
Remove a fused operation by its ID.
Expand All @@ -95,6 +124,7 @@ def remove_fused_operation(self, op_id: str) -> None:
for node in nodes:
self.node_to_fused_node_map.pop(node.name, None)
del self.fusing_data[op_id]
self.fused_op_id_to_quant_config.pop(op_id, None)

def get_fused_node_name(self, node_name: str) -> Optional[str]:
"""
Expand All @@ -117,6 +147,15 @@ def get_node_to_fused_node_map(self) -> Dict[str, str]:
"""
return self.node_to_fused_node_map.copy()

def get_fusing_quantization_config_map(self) -> Dict[str, OpQuantizationConfig]:
"""
Retrieve a copy of the mapping from fused operation IDs to their quantization configurations.

Returns:
A dictionary mapping each fused operation ID to its quantization configuration.
"""
return self.fused_op_id_to_quant_config.copy()

def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
"""
Retrieve the list of nodes for a given fused operation ID.
Expand All @@ -129,6 +168,18 @@ def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
"""
return self.fusing_data.get(op_id)

def get_fused_op_quantization_config(self, op_id: str) -> OpQuantizationConfig:
"""
Retrieve the quantization configuration for a given fused operation ID.

Args:
op_id (str): The identifier for the fused operation.

Returns:
OpQuantizationConfig: The quantization configuration for the operation, or None if not found.
"""
return self.fused_op_id_to_quant_config.get(op_id)

def is_node_in_fused_op(self, node: 'BaseNode') -> bool:
"""
Check if a node is part of any fused operation.
Expand Down Expand Up @@ -216,10 +267,11 @@ def validate(self, graph: 'Graph') -> None:
all_fused_nodes.update(node_set)

# Check 4: Ensure the sequence matches a valid fusing pattern
if not is_valid_fusion(self.fusing_patterns, nodes):
valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
if not is_valid_fusion(valid_fusing_patterns, nodes):
raise ValueError(
f"Fused operation {op_id} does not match any valid fusing pattern "
f"from {self.fusing_patterns}."
f"from {valid_fusing_patterns}."
)

def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool:
Expand All @@ -240,7 +292,8 @@ def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool:
return False

# Check if the provided nodes match a valid fusion pattern
return is_valid_fusion(fusing_patterns=self.fusing_patterns, nodes=nodes)
valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
return is_valid_fusion(fusing_patterns=valid_fusing_patterns, nodes=nodes)

def __repr__(self) -> str:
"""
Expand Down Expand Up @@ -287,8 +340,11 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
if not self._fusing_patterns:
return FusingInfo(fusing_patterns=self._fusing_patterns)

# Extract fusing layer patterns
fusing_layer_patterns = _get_fusing_layer_patterns(self._fusing_patterns)

# Find max fusion
max_layers_fusing = max([len(fusing_pattern) for fusing_pattern in self._fusing_patterns])
max_layer_patterns = max([len(fusing_layer_pattern) for fusing_layer_pattern in fusing_layer_patterns])

# Travel along the graph to find layers for fusing
nodes = graph.get_topo_sorted_nodes()
Expand All @@ -302,9 +358,9 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
continue
# Start fusing search
fusing_nodes = [] # nodes that are candidates for participating in fusing
patterns = copy.deepcopy(self._fusing_patterns)
patterns = copy.deepcopy(fusing_layer_patterns)
next_nodes = [node]
for i in range(max_layers_fusing):
for i in range(max_layer_patterns):
patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
if len(patterns) == 0: # Give up if no more fusion pattern
break
Expand All @@ -314,7 +370,7 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
break

# New fusion
if is_valid_fusion(self._fusing_patterns, fusing_nodes):
if is_valid_fusion(fusing_layer_patterns, fusing_nodes):
fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes)
assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
fusing_info[fused_op_id] = tuple(fusing_nodes)
Expand Down Expand Up @@ -371,3 +427,15 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -
if counter == fusion_depth:
return True
return False


def _get_fusing_layer_patterns(fusing_patterns: List[Dict[Any, OpQuantizationConfig]]) -> List[List[Any]]:
"""
Extracts the fusing layer patterns from the provided fusing patterns.
Args:
fusing_patterns: List of patterns of layers/LayerFilterParams to fuse and their mapping quantization config.

Returns:
supported fusing layer patterns
"""
return [f.get(FUSED_LAYER_PATTERN) for f in fusing_patterns]
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
OpQuantizationConfig, QuantizationConfigOptions
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc

from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG


class FrameworkQuantizationCapabilities(ImmutableClass):
"""
Attach framework information to a modeled hardware.
Expand Down Expand Up @@ -94,20 +97,26 @@ def get_layers_by_opset(self, op: OperatorsSetBase) -> List[Any]:
"""
return self.op_sets_to_layers.get_layers_by_op(op)

def get_fusing_patterns(self) -> List[List[Any]]:
def get_fusing_patterns(self) -> List[Dict[List[Any], OpQuantizationConfig]]:
"""

Returns: List of patterns of layers/LayerFilterParams to fuse.
Returns: List of patterns of layers/LayerFilterParams to fuse and their mapping quantization config.

"""
res = []

patterns = []
if self.tpc.fusing_patterns is None:
return res
return patterns

for p in self.tpc.fusing_patterns:
res = []
ops = [self.get_layers_by_opset(x) for x in p.operator_groups]
res.extend(itertools.product(*ops))
return [list(x) for x in res]

fused_op_quant_config = getattr(p, FUSED_OP_QUANT_CONFIG, None)
patterns.extend({FUSED_LAYER_PATTERN: list(x), FUSED_OP_QUANT_CONFIG: fused_op_quant_config} for x in res)

return patterns

def get_info(self) -> Dict[str, Any]:
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/keras_tests/non_parallel_tests/test_keras_tpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from keras import Input

import model_compression_toolkit as mct
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.constants import TENSORFLOW, FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \
QNNPACK_TP_MODEL, TFLITE_TP_MODEL, KERNEL_ATTR, BIAS_ATTR, KERAS_KERNEL, BIAS, WEIGHTS_N_BITS
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
Expand Down Expand Up @@ -250,7 +250,7 @@ def test_keras_fusing_patterns(self):

fusings = hm_keras.get_fusing_patterns()
self.assertEqual(len(fusings), 2)
p0, p1 = fusings[0], fusings[1]
p0, p1 = fusings[0].get(FUSED_LAYER_PATTERN), fusings[1].get(FUSED_LAYER_PATTERN)

self.assertEqual(len(p0), 3)
self.assertEqual(p0[0], Conv2D)
Expand Down
10 changes: 5 additions & 5 deletions tests/pytorch_tests/function_tests/test_pytorch_tpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
from model_compression_toolkit.defaultdict import DefaultDict
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.constants import PYTORCH, FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \
TFLITE_TP_MODEL, QNNPACK_TP_MODEL, KERNEL_ATTR, WEIGHTS_N_BITS, PYTORCH_KERNEL, BIAS_ATTR, BIAS
Expand Down Expand Up @@ -236,15 +236,15 @@ def test_pytorch_fusing_patterns(self):
fusing_patterns=tuple(fusing_patterns),
add_metadata=False)

hm_keras = FrameworkQuantizationCapabilities(hm)
with hm_keras:
hm_torch = FrameworkQuantizationCapabilities(hm)
with hm_torch:
OperationsSetToLayers("opA", [torch.conv2d])
OperationsSetToLayers("opB", [torch.tanh])
OperationsSetToLayers("opC", [LayerFilterParams(torch.relu, Greater("max_value", 7), negative_slope=0)])

fusings = hm_keras.get_fusing_patterns()
fusings = hm_torch.get_fusing_patterns()
self.assertEqual(len(fusings), 2)
p0, p1 = fusings[0], fusings[1]
p0, p1 = fusings[0].get(FUSED_LAYER_PATTERN), fusings[1].get(FUSED_LAYER_PATTERN)

self.assertEqual(len(p0), 3)
self.assertEqual(p0[0], torch.conv2d)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest
from model_compression_toolkit.core.common.graph.base_graph import OutTensor

from model_compression_toolkit.constants import FLOAT_BITWIDTH
from model_compression_toolkit.constants import FLOAT_BITWIDTH, FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
from model_compression_toolkit.core import ResourceUtilization
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
Expand Down Expand Up @@ -555,7 +555,8 @@ def test_compute_cuts_random_fusion_valid_utilization(self, seed, disable_quanti
if i + fuse_len <= num_nodes:
fused = tuple(nodes[j] for j in range(i, i + fuse_len))
fused_name = f"FusedNode_{'_'.join(n.name for n in fused)}"
fused_patterns.append([n.layer_class for n in fused])
fused_pattern = {FUSED_LAYER_PATTERN: [n.layer_class for n in fused], FUSED_OP_QUANT_CONFIG: None}
fused_patterns.append(fused_pattern)
fused_data[fused_name] = fused
i += fuse_len
else:
Expand Down
Loading