5 changes: 5 additions & 0 deletions model_compression_toolkit/constants.py
@@ -138,3 +138,8 @@
NODE_NAME = 'node_name'
TOTAL_SIZE = 'total_size'
NODE_OUTPUT_INDEX = 'node_output_index'


# Fusing Patterns constants
FUSED_LAYER_PATTERN = 'fused_layer_pattern'
FUSED_OP_QUANT_CONFIG = 'fused_op_quantization_config'
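
For orientation, the rest of this PR passes each fusing pattern around as a plain dict keyed by these two constants; a minimal sketch of the shape (mirroring the test fixtures added below, with illustrative values):

```python
from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG

# One fusing pattern as consumed by FusingInfo/FusingInfoGenerator: the ordered
# layer sequence to fuse plus an optional quantization config for the fused op.
pattern = {
    FUSED_LAYER_PATTERN: ["Conv2d", "ReLU"],  # layers/LayerFilterParams, in fusing order
    FUSED_OP_QUANT_CONFIG: None,              # an OpQuantizationConfig instance, or None
}
```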
58 changes: 52 additions & 6 deletions model_compression_toolkit/core/common/fusion/fusing_info.py
@@ -14,6 +14,8 @@
# ==============================================================================

from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig
from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
from dataclasses import dataclass, field

from typing import Optional, List, Dict, Any, Tuple
@@ -41,6 +43,7 @@ class FusingInfo:
fusing_patterns: any = None
fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict)
node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict)
fusing_data_to_quantization_config_map: Dict[str, OpQuantizationConfig] = field(default_factory=dict)

def __post_init__(self):
"""Validates and initializes mappings after dataclass instantiation."""
@@ -49,6 +52,7 @@ def __post_init__(self):
assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}"

self._init_node_mapping()
self._init_quantization_config_map()

def _init_node_mapping(self) -> None:
"""
@@ -59,6 +63,17 @@ def _init_node_mapping(self) -> None:
for node in nodes:
self.node_to_fused_node_map[node.name] = op_id

def _init_quantization_config_map(self) -> None:
"""
Init the mapping between fused operation IDs and their quantization configurations.
"""
self.fusing_data_to_quantization_config_map.clear()
if self.fusing_patterns is not None:
for op_id, nodes in self.fusing_data.items():
for fusing_pattern in self.fusing_patterns:
if is_valid_fusion([fusing_pattern], nodes):
self.fusing_data_to_quantization_config_map[op_id] = fusing_pattern.get(FUSED_OP_QUANT_CONFIG)

def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
"""
Add a new fused operation with the given ID and set of nodes.
@@ -78,6 +93,12 @@ def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
for node in nodes:
self.node_to_fused_node_map[node.name] = op_id

# Update the quantization config mapping for this operation
if self.fusing_patterns is not None:
for fusing_pattern in self.fusing_patterns:
if is_valid_fusion([fusing_pattern], nodes):
self.fusing_data_to_quantization_config_map[op_id] = fusing_pattern.get(FUSED_OP_QUANT_CONFIG)
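
This block repeats the lookup done in `_init_quantization_config_map` above, and neither loop breaks on a match, so when several patterns match the same node tuple the last matching pattern's config wins. A sketch of a shared helper both call sites could use (hypothetical, not part of this PR):

```python
def _find_matching_qconfig(self, nodes: Tuple['BaseNode']) -> Optional[OpQuantizationConfig]:
    # Hypothetical helper preserving the existing semantics: the last matching
    # pattern's FUSED_OP_QUANT_CONFIG wins; no match (or no patterns) yields None.
    qconfig = None
    for fusing_pattern in self.fusing_patterns or []:
        if is_valid_fusion([fusing_pattern], nodes):
            qconfig = fusing_pattern.get(FUSED_OP_QUANT_CONFIG)
    return qconfig
```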

def remove_fused_operation(self, op_id: str) -> None:
"""
Remove a fused operation by its ID.
@@ -95,6 +116,7 @@ def remove_fused_operation(self, op_id: str) -> None:
for node in nodes:
self.node_to_fused_node_map.pop(node.name, None)
del self.fusing_data[op_id]
self.fusing_data_to_quantization_config_map.pop(op_id, None)

def get_fused_node_name(self, node_name: str) -> Optional[str]:
"""
@@ -117,6 +139,15 @@ def get_node_to_fused_node_map(self) -> Dict[str, str]:
"""
return self.node_to_fused_node_map.copy()

def get_fusing_quantization_config_map(self) -> Dict[str, OpQuantizationConfig]:
"""
Retrieve a copy of the mapping from fused operation IDs to their quantization configurations.

Returns:
A dictionary mapping each fused operation ID to its quantization configuration.
"""
return self.fusing_data_to_quantization_config_map.copy()

def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
"""
Retrieve the list of nodes for a given fused operation ID.
@@ -129,6 +160,18 @@ def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
"""
return self.fusing_data.get(op_id)

def get_fused_op_quantization_config(self, op_id: str) -> Optional[OpQuantizationConfig]:
"""
Retrieve the quantization configuration for a given fused operation ID.

Args:
op_id (str): The identifier for the fused operation.

Returns:
Optional[OpQuantizationConfig]: The quantization configuration for the operation, or None if not found.
"""
return self.fusing_data_to_quantization_config_map.get(op_id)
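
Putting the new accessors together, a hedged usage sketch (`fusing_patterns` and `graph` are assumed to already exist; `FusingInfoGenerator` and `FUSED_OP_ID_PREFIX` come from this module):

```python
fi = FusingInfoGenerator(fusing_patterns).generate_fusing_info(graph)

qc_map = fi.get_fusing_quantization_config_map()  # copy: {op_id: OpQuantizationConfig or None}
op_id = f"{FUSED_OP_ID_PREFIX}conv_relu"          # op IDs are prefixed joined node names
qc = fi.get_fused_op_quantization_config(op_id)   # returns None for unknown op IDs
```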

def is_node_in_fused_op(self, node: 'BaseNode') -> bool:
"""
Check if a node is part of any fused operation.
@@ -284,11 +327,13 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
- Fusions are linear sequences (each node has exactly one successor).
- Each node belongs to at most one fused operation.
"""
if not self._fusing_patterns:
fusing_layer_patterns = [f.get(FUSED_LAYER_PATTERN) for f in self._fusing_patterns]

if not fusing_layer_patterns:
return FusingInfo(fusing_patterns=self._fusing_patterns)

# Find max fusion
max_layers_fusing = max([len(fusing_pattern) for fusing_pattern in self._fusing_patterns])
max_fusing_layer_patterns = max([len(fusing_pattern.get(FUSED_LAYER_PATTERN)) for fusing_pattern in self._fusing_patterns])

# Travel along the graph to find layers for fusing
nodes = graph.get_topo_sorted_nodes()
@@ -302,9 +347,9 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
continue
# Start fusing search
fusing_nodes = [] # nodes that are candidates for participating in fusing
patterns = copy.deepcopy(self._fusing_patterns)
patterns = copy.deepcopy(fusing_layer_patterns)
next_nodes = [node]
for i in range(max_layers_fusing):
for i in range(max_fusing_layer_patterns):
patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
if len(patterns) == 0: # Give up if no more fusion pattern
break
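
For intuition, `get_valid_fusing_patterns_for_node` narrows the surviving layer patterns one position at a time as the walk advances; a simplified, self-contained illustration of that narrowing (plain strings stand in for real layer types/LayerFilterParams):

```python
# Simplified illustration, not the library code: keep only the patterns whose
# i-th entry matches the i-th node visited, and stop once none survive.
patterns = [["Conv2d", "ReLU"], ["Conv2d", "Tanh"], ["Linear", "Softmax"]]
walk = ["Conv2d", "ReLU"]  # node types encountered along the graph walk

for i, node_type in enumerate(walk):
    patterns = [p for p in patterns if len(p) > i and p[i] == node_type]
    if not patterns:
        break

print(patterns)  # [['Conv2d', 'ReLU']]
```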
@@ -361,10 +406,11 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -
if fusion_depth <= 1:
return False
for fusing_pattern in fusing_patterns:
if fusion_depth != len(fusing_pattern):
fusing_layer_pattern = fusing_pattern.get(FUSED_LAYER_PATTERN)
if fusion_depth != len(fusing_layer_pattern):
continue
counter = 0
for i, layer in enumerate(fusing_pattern):
for i, layer in enumerate(fusing_layer_pattern):
if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or \
nodes[i].is_match_type(layer):
counter += 1
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py
@@ -31,6 +31,9 @@
OpQuantizationConfig, QuantizationConfigOptions
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc

from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG


class FrameworkQuantizationCapabilities(ImmutableClass):
"""
Attach framework information to a modeled hardware.
@@ -97,17 +100,30 @@ def get_layers_by_opset(self, op: OperatorsSetBase) -> List[Any]:
def get_fusing_patterns(self) -> List[Dict[str, Any]]:
"""

Returns: List of patterns of layers/LayerFilterParams to fuse.
Returns: List of fusing-pattern dicts, each pairing a sequence of layers/LayerFilterParams to fuse with its matching quantization config.

"""
res = []

patterns = []
if self.tpc.fusing_patterns is None:
return res
return patterns

for p in self.tpc.fusing_patterns:
res = []
ops = [self.get_layers_by_opset(x) for x in p.operator_groups]
res.extend(itertools.product(*ops))
return [list(x) for x in res]

fused_op_quantization_config = getattr(p, 'fused_op_quantization_config', None)

for x in res:
pattern = {FUSED_LAYER_PATTERN: list(x),
FUSED_OP_QUANT_CONFIG: fused_op_quantization_config}
patterns.append(pattern)

return patterns
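
Callers of `get_fusing_patterns` therefore now receive a list of dicts rather than bare layer lists; with the key strings from constants.py, the result has roughly this shape (illustrative):

```python
# Illustrative return value after this change (layer entries are the framework
# layer classes/LayerFilterParams resolved per operator set):
# [
#     {'fused_layer_pattern': [Conv2d, ReLU],
#      'fused_op_quantization_config': <OpQuantizationConfig or None>},
#     ...
# ]
```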

def get_info(self) -> Dict[str, Any]:
"""
166 changes: 165 additions & 1 deletion tests_pytest/common_tests/unit_tests/core/test_fusion_info.py
@@ -19,6 +19,12 @@
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfoGenerator, FUSED_OP_ID_PREFIX, FusingInfo
from model_compression_toolkit.target_platform_capabilities import FrameworkQuantizationCapabilities
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG

from tests.common_tests.helpers.generate_test_tpc import generate_test_attr_configs, generate_test_op_qc

# Setup TEST_QC and TEST_QCO for testing.
TEST_QC = generate_test_op_qc(**generate_test_attr_configs())


class MockBaseNode:
Expand All @@ -42,7 +48,8 @@ def fusing_patterns():
"""
- Returns predefined fusing patterns: Conv2D + ReLU and Linear + Softmax.
"""
return [["Conv2d", "ReLU"], ["Linear", "Softmax"]]
return [{FUSED_LAYER_PATTERN: ["Conv2d", "ReLU"], FUSED_OP_QUANT_CONFIG: None},
{FUSED_LAYER_PATTERN: ["Linear", "Softmax"], FUSED_OP_QUANT_CONFIG: None}]


@pytest.fixture
@@ -219,3 +226,160 @@ def test_is_node_in_fused_op_returns_false_for_absent_node(mock_graph, fusing_in
unrelated = MockBaseNode("unrelated")
assert not fi.is_node_in_fused_op(unrelated)



def create_mock_base_node(name: str, layer_class: str):
"""
Function for creating the mock nodes required for a simple neural network structure.
Enables node name, layer class, type, and type checking method.
"""

dummy_initalize = {'framework_attr': {},
'input_shape': (),
'output_shape': (),
'weights': {}}

real_node = BaseNode(name=name, layer_class=layer_class, **dummy_initalize)

node = Mock(spec=real_node)
node.is_match_type = real_node.is_match_type
node.layer_class = layer_class
node.name = name

return node

@pytest.fixture
def fusing_patterns_with_qconfig():
"""
- Returns predefined fusing patterns: Conv2D + ReLU, Conv2D + Tanh, and Linear + Softmax.
"""
return [{FUSED_LAYER_PATTERN: ["Conv2d", "ReLU"], FUSED_OP_QUANT_CONFIG: TEST_QC},
{FUSED_LAYER_PATTERN: ["Conv2d", "Tanh"], FUSED_OP_QUANT_CONFIG: None},
{FUSED_LAYER_PATTERN: ["Linear", "Softmax"], FUSED_OP_QUANT_CONFIG: TEST_QC }]

@pytest.fixture
def fusing_info_generator_with_qconfig(fusing_patterns_with_qconfig):
"""
Creates a FusingInfoGenerator using the fusing patterns.
"""
return FusingInfoGenerator(fusing_patterns_with_qconfig)

@pytest.fixture
def mock_qconfig_set_nodes():
"""
Creates mock nodes representing a simple neural network structure.
- Nodes: Conv2D, ReLU, Conv2D, Tanh, Linear, Softmax.
"""
node1 = create_mock_base_node(name='conv', layer_class='Conv2d')
node2 = create_mock_base_node(name='relu', layer_class='ReLU')
node3 = create_mock_base_node(name='conv_2', layer_class='Conv2d')
node4 = create_mock_base_node(name='tanh', layer_class='Tanh')
node5 = create_mock_base_node(name='linear', layer_class='Linear')
node6 = create_mock_base_node(name='softmax', layer_class='Softmax')

return [node1, node2, node3, node4, node5, node6]


@pytest.fixture
def mock_qconfig_set_graph(mock_qconfig_set_nodes):
"""
Creates a mock graph with topologically sorted nodes and defined connectivity.
- Implements `get_next_nodes` and `get_prev_nodes` to maintain linear order.
"""
mock_nodes = mock_qconfig_set_nodes

graph = Mock()
graph.nodes.return_value = mock_nodes
graph.get_topo_sorted_nodes.return_value = mock_nodes

adjacency = {
mock_nodes[0]: [mock_nodes[1]], # conv -> relu
mock_nodes[1]: [mock_nodes[2]], # relu -> conv_2
mock_nodes[2]: [mock_nodes[3]], # conv_2 -> tanh
mock_nodes[3]: [mock_nodes[4]], # tanh -> linear
mock_nodes[4]: [mock_nodes[5]], # linear -> softmax
mock_nodes[5]: [] # softmax has no outputs
}

reverse_adjacency = {
mock_nodes[0]: [], # conv has no inputs
mock_nodes[1]: [mock_nodes[0]], # relu <- conv
mock_nodes[2]: [mock_nodes[1]], # conv_2 <- relu
mock_nodes[3]: [mock_nodes[2]], # tanh <- conv_2
mock_nodes[4]: [mock_nodes[3]], # linear <- tanh
mock_nodes[5]: [mock_nodes[4]] # softmax <- linear
}

graph.get_next_nodes.side_effect = lambda node: adjacency.get(node, [])
graph.get_prev_nodes.side_effect = lambda node: reverse_adjacency.get(node, [])

return graph


def test_fusing_info_qconfig_mapping(mock_qconfig_set_graph, fusing_info_generator_with_qconfig):
"""
Tests that each node is correctly mapped to its fused quantization configs.
"""
fi = fusing_info_generator_with_qconfig.generate_fusing_info(mock_qconfig_set_graph)
fi_qconfig_map = fi.get_fusing_quantization_config_map()

expected_op1_id = f"{FUSED_OP_ID_PREFIX}conv_relu"
expected_op2_id = f"{FUSED_OP_ID_PREFIX}conv_2_tanh"
expected_op3_id = f"{FUSED_OP_ID_PREFIX}linear_softmax"

assert len(fi_qconfig_map) == 3
assert fi_qconfig_map[expected_op1_id] == TEST_QC
assert fi_qconfig_map[expected_op2_id] is None
assert fi_qconfig_map[expected_op3_id] == TEST_QC


def test_add_fused_operation_adds_data_and_qconfig(mock_qconfig_set_graph, fusing_info_generator_with_qconfig):
"""
Tests whether the added node is correctly assigned the fused quantization config.
"""

fi = fusing_info_generator_with_qconfig.generate_fusing_info(mock_qconfig_set_graph)
fi_qconfig_map = fi.get_fusing_quantization_config_map()

### Checking the number of mappings before addition
assert len(fi_qconfig_map) == 3

node1 = create_mock_base_node(name='conv_a', layer_class='Conv2d')
node2 = create_mock_base_node(name='relu_b', layer_class='ReLU')

op_id = f"{FUSED_OP_ID_PREFIX}conv_a_relu_b"
fi.add_fused_operation(op_id, (node1, node2))
fi_qconfig_map = fi.get_fusing_quantization_config_map()

### Checking the mapping information after addition
assert op_id in fi.get_all_fused_operations()
assert fi.get_fused_node_name("conv_a") == op_id
assert fi.get_fused_node_name("relu_b") == op_id

assert len(fi_qconfig_map) == 4
assert fi.get_fused_op_quantization_config(op_id) == TEST_QC


def test_remove_fusing_data_and_qconfig(mock_qconfig_set_graph, fusing_info_generator_with_qconfig, mock_qconfig_set_nodes):
"""
Tests that the fused quantization config for the specified operation is removed from the map.
"""

fi = fusing_info_generator_with_qconfig.generate_fusing_info(mock_qconfig_set_graph)

### Delete Conv2D + ReLU pattern.
conv_node, relu_node, _, _, _, _ = mock_qconfig_set_nodes
op_id = f"{FUSED_OP_ID_PREFIX}conv_relu"

### Checking the mapping information before deletion.
assert len(fi.get_fusing_quantization_config_map()) == 3
assert fi.get_fused_op_quantization_config(op_id) == TEST_QC
assert fi.get_fused_nodes(op_id) == (conv_node, relu_node)

fi.remove_fused_operation(op_id)
fi_qconfig_map = fi.get_fusing_quantization_config_map()

### Checking the mapping information after deletion.
assert len(fi_qconfig_map) == 2
assert fi.get_fused_op_quantization_config(op_id) is None
assert fi.get_fused_nodes(op_id) is None