Skip to content

Commit 056da80

Browse files
authored
Add functionality to FusingInfo class to preserve quantization config (#1426)
* Add fuse_op_quantization_config field to schema v2 of Fusing class * Renamed unittest filename * Rename test class name * Remove unnecessary validation * Experimentally applying quant info to fusion info * Experimentally Code at fusing_info.py * modify experimentally code * modify initalize method * apply quantization info to fusinginfo * change variable name and conditional branching * adding unittest * modify test expect(missing type) * Fixed review feedback ・change for fusing_info.py and fqc's get_fusing_patterns method. ・change test patterns. * Add other quant_op_cfg * modify * Adding error conditions * change the method of removing fused_op_id_to_quant_config * fix line break
1 parent 09ed05c commit 056da80

File tree

9 files changed

+382
-36
lines changed

9 files changed

+382
-36
lines changed

model_compression_toolkit/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,8 @@
138138
NODE_NAME = 'node_name'
139139
TOTAL_SIZE = 'total_size'
140140
NODE_OUTPUT_INDEX = 'node_output_index'
141+
142+
143+
# Fusing Patterns constants
144+
FUSED_LAYER_PATTERN = 'fused_layer_pattern'
145+
FUSED_OP_QUANT_CONFIG = 'fused_op_quantization_config'

model_compression_toolkit/core/common/fusion/fusing_info.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==============================================================================
1515

1616
from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
17+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig
18+
from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
1719
from dataclasses import dataclass, field
1820

1921
from typing import Optional, List, Dict, Any, Tuple
@@ -41,6 +43,7 @@ class FusingInfo:
4143
fusing_patterns: any = None
4244
fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict)
4345
node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict)
46+
fused_op_id_to_quant_config: Dict[str, OpQuantizationConfig] = field(default_factory=dict)
4447

4548
def __post_init__(self):
4649
"""Validates and initializes mappings after dataclass instantiation."""
@@ -49,6 +52,7 @@ def __post_init__(self):
4952
assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}"
5053

5154
self._init_node_mapping()
55+
self._init_quantization_config_map()
5256

5357
def _init_node_mapping(self) -> None:
5458
"""
@@ -59,6 +63,15 @@ def _init_node_mapping(self) -> None:
5963
for node in nodes:
6064
self.node_to_fused_node_map[node.name] = op_id
6165

66+
def _init_quantization_config_map(self) -> None:
67+
"""
68+
Init the mapping between fused operation IDs and their quantization configurations.
69+
"""
70+
self.fused_op_id_to_quant_config.clear()
71+
if self.fusing_patterns is not None:
72+
for op_id, nodes in self.fusing_data.items():
73+
self.set_fused_op_quantization_config(op_id, nodes)
74+
6275
def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
6376
"""
6477
Add a new fused operation with the given ID and set of nodes.
@@ -78,6 +91,22 @@ def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
7891
for node in nodes:
7992
self.node_to_fused_node_map[node.name] = op_id
8093

94+
# Update the quantization config mapping for this operation
95+
if self.fusing_patterns is not None:
96+
self.set_fused_op_quantization_config(op_id, nodes)
97+
98+
def set_fused_op_quantization_config(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
99+
"""
100+
Set the quantization configuration for a given fused operation ID.
101+
102+
Args:
103+
op_id (str): The identifier for the fused operation.
104+
nodes (Tuple[BaseNode]): The tuple of nodes that form the fused operation.
105+
"""
106+
fusing_pattern = next((fp for fp in self.fusing_patterns if is_valid_fusion([fp.get(FUSED_LAYER_PATTERN)], nodes)), None)
107+
if fusing_pattern is not None:
108+
self.fused_op_id_to_quant_config[op_id] = fusing_pattern.get(FUSED_OP_QUANT_CONFIG)
109+
81110
def remove_fused_operation(self, op_id: str) -> None:
82111
"""
83112
Remove a fused operation by its ID.
@@ -95,6 +124,7 @@ def remove_fused_operation(self, op_id: str) -> None:
95124
for node in nodes:
96125
self.node_to_fused_node_map.pop(node.name, None)
97126
del self.fusing_data[op_id]
127+
self.fused_op_id_to_quant_config.pop(op_id, None)
98128

99129
def get_fused_node_name(self, node_name: str) -> Optional[str]:
100130
"""
@@ -117,6 +147,15 @@ def get_node_to_fused_node_map(self) -> Dict[str, str]:
117147
"""
118148
return self.node_to_fused_node_map.copy()
119149

150+
def get_fusing_quantization_config_map(self) -> Dict[str, OpQuantizationConfig]:
151+
"""
152+
Retrieve a copy of the mapping from fused operation IDs to their quantization configurations.
153+
154+
Returns:
155+
A dictionary mapping each fused operation ID to its quantization configuration.
156+
"""
157+
return self.fused_op_id_to_quant_config.copy()
158+
120159
def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
121160
"""
122161
Retrieve the list of nodes for a given fused operation ID.
@@ -129,6 +168,18 @@ def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
129168
"""
130169
return self.fusing_data.get(op_id)
131170

171+
def get_fused_op_quantization_config(self, op_id: str) -> OpQuantizationConfig:
172+
"""
173+
Retrieve the quantization configuration for a given fused operation ID.
174+
175+
Args:
176+
op_id (str): The identifier for the fused operation.
177+
178+
Returns:
179+
OpQuantizationConfig: The quantization configuration for the operation, or None if not found.
180+
"""
181+
return self.fused_op_id_to_quant_config.get(op_id)
182+
132183
def is_node_in_fused_op(self, node: 'BaseNode') -> bool:
133184
"""
134185
Check if a node is part of any fused operation.
@@ -216,10 +267,11 @@ def validate(self, graph: 'Graph') -> None:
216267
all_fused_nodes.update(node_set)
217268

218269
# Check 4: Ensure the sequence matches a valid fusing pattern
219-
if not is_valid_fusion(self.fusing_patterns, nodes):
270+
valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
271+
if not is_valid_fusion(valid_fusing_patterns, nodes):
220272
raise ValueError(
221273
f"Fused operation {op_id} does not match any valid fusing pattern "
222-
f"from {self.fusing_patterns}."
274+
f"from {valid_fusing_patterns}."
223275
)
224276

225277
def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool:
@@ -240,7 +292,8 @@ def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool:
240292
return False
241293

242294
# Check if the provided nodes match a valid fusion pattern
243-
return is_valid_fusion(fusing_patterns=self.fusing_patterns, nodes=nodes)
295+
valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
296+
return is_valid_fusion(fusing_patterns=valid_fusing_patterns, nodes=nodes)
244297

245298
def __repr__(self) -> str:
246299
"""
@@ -287,8 +340,11 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
287340
if not self._fusing_patterns:
288341
return FusingInfo(fusing_patterns=self._fusing_patterns)
289342

343+
# Extract fusing layer patterns
344+
fusing_layer_patterns = _get_fusing_layer_patterns(self._fusing_patterns)
345+
290346
# Find max fusion
291-
max_layers_fusing = max([len(fusing_pattern) for fusing_pattern in self._fusing_patterns])
347+
max_layer_patterns = max([len(fusing_layer_pattern) for fusing_layer_pattern in fusing_layer_patterns])
292348

293349
# Travel along the graph to find layers for fusing
294350
nodes = graph.get_topo_sorted_nodes()
@@ -302,9 +358,9 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
302358
continue
303359
# Start fusing search
304360
fusing_nodes = [] # nodes that are candidates for participating in fusing
305-
patterns = copy.deepcopy(self._fusing_patterns)
361+
patterns = copy.deepcopy(fusing_layer_patterns)
306362
next_nodes = [node]
307-
for i in range(max_layers_fusing):
363+
for i in range(max_layer_patterns):
308364
patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
309365
if len(patterns) == 0: # Give up if no more fusion pattern
310366
break
@@ -314,7 +370,7 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
314370
break
315371

316372
# New fusion
317-
if is_valid_fusion(self._fusing_patterns, fusing_nodes):
373+
if is_valid_fusion(fusing_layer_patterns, fusing_nodes):
318374
fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes)
319375
assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
320376
fusing_info[fused_op_id] = tuple(fusing_nodes)
@@ -371,3 +427,15 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -
371427
if counter == fusion_depth:
372428
return True
373429
return False
430+
431+
432+
def _get_fusing_layer_patterns(fusing_patterns: List[Dict[Any, OpQuantizationConfig]]) -> List[List[Any]]:
433+
"""
434+
Extracts the fusing layer patterns from the provided fusing patterns.
435+
Args:
436+
fusing_patterns: List of patterns of layers/LayerFilterParams to fuse and their mapping quantization config.
437+
438+
Returns:
439+
supported fusing layer patterns
440+
"""
441+
return [f.get(FUSED_LAYER_PATTERN) for f in fusing_patterns]

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
OpQuantizationConfig, QuantizationConfigOptions
3232
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc
3333

34+
from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
35+
36+
3437
class FrameworkQuantizationCapabilities(ImmutableClass):
3538
"""
3639
Attach framework information to a modeled hardware.
@@ -94,20 +97,26 @@ def get_layers_by_opset(self, op: OperatorsSetBase) -> List[Any]:
9497
"""
9598
return self.op_sets_to_layers.get_layers_by_op(op)
9699

97-
def get_fusing_patterns(self) -> List[List[Any]]:
100+
def get_fusing_patterns(self) -> List[Dict[List[Any], OpQuantizationConfig]]:
98101
"""
99102
100-
Returns: List of patterns of layers/LayerFilterParams to fuse.
103+
Returns: List of patterns of layers/LayerFilterParams to fuse and their mapping quantization config.
101104
102105
"""
103-
res = []
106+
107+
patterns = []
104108
if self.tpc.fusing_patterns is None:
105-
return res
109+
return patterns
110+
106111
for p in self.tpc.fusing_patterns:
112+
res = []
107113
ops = [self.get_layers_by_opset(x) for x in p.operator_groups]
108114
res.extend(itertools.product(*ops))
109-
return [list(x) for x in res]
110115

116+
fused_op_quant_config = getattr(p, FUSED_OP_QUANT_CONFIG, None)
117+
patterns.extend({FUSED_LAYER_PATTERN: list(x), FUSED_OP_QUANT_CONFIG: fused_op_quant_config} for x in res)
118+
119+
return patterns
111120

112121
def get_info(self) -> Dict[str, Any]:
113122
"""

tests/keras_tests/non_parallel_tests/test_keras_tpc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from keras import Input
4343

4444
import model_compression_toolkit as mct
45-
from model_compression_toolkit.constants import TENSORFLOW
45+
from model_compression_toolkit.constants import TENSORFLOW, FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
4646
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \
4747
QNNPACK_TP_MODEL, TFLITE_TP_MODEL, KERNEL_ATTR, BIAS_ATTR, KERAS_KERNEL, BIAS, WEIGHTS_N_BITS
4848
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
@@ -250,7 +250,7 @@ def test_keras_fusing_patterns(self):
250250

251251
fusings = hm_keras.get_fusing_patterns()
252252
self.assertEqual(len(fusings), 2)
253-
p0, p1 = fusings[0], fusings[1]
253+
p0, p1 = fusings[0].get(FUSED_LAYER_PATTERN), fusings[1].get(FUSED_LAYER_PATTERN)
254254

255255
self.assertEqual(len(p0), 3)
256256
self.assertEqual(p0[0], Conv2D)

tests/pytorch_tests/function_tests/test_pytorch_tpc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
2727
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
2828
from model_compression_toolkit.defaultdict import DefaultDict
29-
from model_compression_toolkit.constants import PYTORCH
29+
from model_compression_toolkit.constants import PYTORCH, FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
3030
from model_compression_toolkit.core.common import BaseNode
3131
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \
3232
TFLITE_TP_MODEL, QNNPACK_TP_MODEL, KERNEL_ATTR, WEIGHTS_N_BITS, PYTORCH_KERNEL, BIAS_ATTR, BIAS
@@ -236,15 +236,15 @@ def test_pytorch_fusing_patterns(self):
236236
fusing_patterns=tuple(fusing_patterns),
237237
add_metadata=False)
238238

239-
hm_keras = FrameworkQuantizationCapabilities(hm)
240-
with hm_keras:
239+
hm_torch = FrameworkQuantizationCapabilities(hm)
240+
with hm_torch:
241241
OperationsSetToLayers("opA", [torch.conv2d])
242242
OperationsSetToLayers("opB", [torch.tanh])
243243
OperationsSetToLayers("opC", [LayerFilterParams(torch.relu, Greater("max_value", 7), negative_slope=0)])
244244

245-
fusings = hm_keras.get_fusing_patterns()
245+
fusings = hm_torch.get_fusing_patterns()
246246
self.assertEqual(len(fusings), 2)
247-
p0, p1 = fusings[0], fusings[1]
247+
p0, p1 = fusings[0].get(FUSED_LAYER_PATTERN), fusings[1].get(FUSED_LAYER_PATTERN)
248248

249249
self.assertEqual(len(p0), 3)
250250
self.assertEqual(p0[0], torch.conv2d)

tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import pytest
2121
from model_compression_toolkit.core.common.graph.base_graph import OutTensor
2222

23-
from model_compression_toolkit.constants import FLOAT_BITWIDTH
23+
from model_compression_toolkit.constants import FLOAT_BITWIDTH, FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
2424
from model_compression_toolkit.core import ResourceUtilization
2525
from model_compression_toolkit.core.common import Graph
2626
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
@@ -555,7 +555,8 @@ def test_compute_cuts_random_fusion_valid_utilization(self, seed, disable_quanti
555555
if i + fuse_len <= num_nodes:
556556
fused = tuple(nodes[j] for j in range(i, i + fuse_len))
557557
fused_name = f"FusedNode_{'_'.join(n.name for n in fused)}"
558-
fused_patterns.append([n.layer_class for n in fused])
558+
fused_pattern = {FUSED_LAYER_PATTERN: [n.layer_class for n in fused], FUSED_OP_QUANT_CONFIG: None}
559+
fused_patterns.append(fused_pattern)
559560
fused_data[fused_name] = fused
560561
i += fuse_len
561562
else:

0 commit comments

Comments
 (0)