1414# ==============================================================================
1515
1616from model_compression_toolkit .target_platform_capabilities import LayerFilterParams
17+ from model_compression_toolkit .target_platform_capabilities .schema .mct_current_schema import OpQuantizationConfig
18+ from model_compression_toolkit .constants import FUSED_LAYER_PATTERN , FUSED_OP_QUANT_CONFIG
1719from dataclasses import dataclass , field
1820
1921from typing import Optional , List , Dict , Any , Tuple
@@ -41,6 +43,7 @@ class FusingInfo:
4143 fusing_patterns : any = None
4244 fusing_data : Dict [str , Tuple ['BaseNode' ]] = field (default_factory = dict )
4345 node_to_fused_node_map : Dict [str , str ] = field (init = False , default_factory = dict )
46+ fused_op_id_to_quant_config : Dict [str , OpQuantizationConfig ] = field (default_factory = dict )
4447
4548 def __post_init__ (self ):
4649 """Validates and initializes mappings after dataclass instantiation."""
@@ -49,6 +52,7 @@ def __post_init__(self):
4952 assert isinstance (op_nodes , tuple ) and len (op_nodes ) > 1 , f"Found invalid fused op nodes: { op_nodes } "
5053
5154 self ._init_node_mapping ()
55+ self ._init_quantization_config_map ()
5256
5357 def _init_node_mapping (self ) -> None :
5458 """
@@ -59,6 +63,15 @@ def _init_node_mapping(self) -> None:
5963 for node in nodes :
6064 self .node_to_fused_node_map [node .name ] = op_id
6165
def _init_quantization_config_map(self) -> None:
    """
    Rebuild the mapping from fused operation IDs to their quantization configurations.

    Any existing entries are discarded first. When fusing patterns are defined, every
    fused operation currently held in the fusing data is passed to
    set_fused_op_quantization_config; otherwise the mapping is left empty.
    """
    self.fused_op_id_to_quant_config.clear()

    # Nothing to resolve against when no fusing patterns were supplied.
    if self.fusing_patterns is None:
        return

    for fused_id, fused_nodes in self.fusing_data.items():
        self.set_fused_op_quantization_config(fused_id, fused_nodes)
74+
6275 def add_fused_operation (self , op_id : str , nodes : Tuple ['BaseNode' ]) -> None :
6376 """
6477 Add a new fused operation with the given ID and set of nodes.
@@ -78,6 +91,22 @@ def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
7891 for node in nodes :
7992 self .node_to_fused_node_map [node .name ] = op_id
8093
94+ # Update the quantization config mapping for this operation
95+ if self .fusing_patterns is not None :
96+ self .set_fused_op_quantization_config (op_id , nodes )
97+
def set_fused_op_quantization_config(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
    """
    Set the quantization configuration for a given fused operation ID.

    The first entry of the fusing patterns whose layer pattern matches the given nodes
    supplies the configuration. If no fusing patterns are defined, or none of them
    matches, the mapping is left untouched.

    Args:
        op_id (str): The identifier for the fused operation.
        nodes (Tuple[BaseNode]): The tuple of nodes that form the fused operation.
    """
    # fusing_patterns defaults to None. The internal callers check for that before calling,
    # but this method is public, so guard here as well instead of failing on iteration.
    if self.fusing_patterns is None:
        return

    for fusing_pattern in self.fusing_patterns:
        # is_valid_fusion takes a list of layer patterns, hence the single-item list.
        if is_valid_fusion([fusing_pattern.get(FUSED_LAYER_PATTERN)], nodes):
            self.fused_op_id_to_quant_config[op_id] = fusing_pattern.get(FUSED_OP_QUANT_CONFIG)
            return
109+
81110 def remove_fused_operation (self , op_id : str ) -> None :
82111 """
83112 Remove a fused operation by its ID.
@@ -95,6 +124,7 @@ def remove_fused_operation(self, op_id: str) -> None:
95124 for node in nodes :
96125 self .node_to_fused_node_map .pop (node .name , None )
97126 del self .fusing_data [op_id ]
127+ self .fused_op_id_to_quant_config .pop (op_id , None )
98128
99129 def get_fused_node_name (self , node_name : str ) -> Optional [str ]:
100130 """
@@ -117,6 +147,15 @@ def get_node_to_fused_node_map(self) -> Dict[str, str]:
117147 """
118148 return self .node_to_fused_node_map .copy ()
119149
def get_fusing_quantization_config_map(self) -> Dict[str, OpQuantizationConfig]:
    """
    Return a shallow copy of the fused-operation-ID to quantization-configuration mapping.

    Returns:
        A new dictionary keyed by fused operation ID; adding or removing keys in it
        does not affect the mapping stored on this object.
    """
    config_map_snapshot = self.fused_op_id_to_quant_config.copy()
    return config_map_snapshot
158+
120159 def get_fused_nodes (self , op_id : str ) -> Optional [List ['BaseNode' ]]:
121160 """
122161 Retrieve the list of nodes for a given fused operation ID.
@@ -129,6 +168,18 @@ def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
129168 """
130169 return self .fusing_data .get (op_id )
131170
def get_fused_op_quantization_config(self, op_id: str) -> Optional[OpQuantizationConfig]:
    """
    Retrieve the quantization configuration for a given fused operation ID.

    Args:
        op_id (str): The identifier for the fused operation.

    Returns:
        Optional[OpQuantizationConfig]: The quantization configuration stored for the
        operation, or None if no configuration is stored under this ID.
    """
    # dict.get returns None for unknown IDs, which is why the return type is Optional.
    return self.fused_op_id_to_quant_config.get(op_id)
182+
132183 def is_node_in_fused_op (self , node : 'BaseNode' ) -> bool :
133184 """
134185 Check if a node is part of any fused operation.
@@ -216,10 +267,11 @@ def validate(self, graph: 'Graph') -> None:
216267 all_fused_nodes .update (node_set )
217268
218269 # Check 4: Ensure the sequence matches a valid fusing pattern
219- if not is_valid_fusion (self .fusing_patterns , nodes ):
270+ valid_fusing_patterns = _get_fusing_layer_patterns (self .fusing_patterns )
271+ if not is_valid_fusion (valid_fusing_patterns , nodes ):
220272 raise ValueError (
221273 f"Fused operation { op_id } does not match any valid fusing pattern "
222- f"from { self . fusing_patterns } ."
274+ f"from { valid_fusing_patterns } ."
223275 )
224276
225277 def is_nodes_eligible_to_be_fused (self , nodes : List ['BaseNode' ]) -> bool :
@@ -240,7 +292,8 @@ def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool:
240292 return False
241293
242294 # Check if the provided nodes match a valid fusion pattern
243- return is_valid_fusion (fusing_patterns = self .fusing_patterns , nodes = nodes )
295+ valid_fusing_patterns = _get_fusing_layer_patterns (self .fusing_patterns )
296+ return is_valid_fusion (fusing_patterns = valid_fusing_patterns , nodes = nodes )
244297
245298 def __repr__ (self ) -> str :
246299 """
@@ -287,8 +340,11 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
287340 if not self ._fusing_patterns :
288341 return FusingInfo (fusing_patterns = self ._fusing_patterns )
289342
343+ # Extract fusing layer patterns
344+ fusing_layer_patterns = _get_fusing_layer_patterns (self ._fusing_patterns )
345+
290346 # Find max fusion
291- max_layers_fusing = max ([len (fusing_pattern ) for fusing_pattern in self . _fusing_patterns ])
347+ max_layer_patterns = max ([len (fusing_layer_pattern ) for fusing_layer_pattern in fusing_layer_patterns ])
292348
293349 # Travel along the graph to find layers for fusing
294350 nodes = graph .get_topo_sorted_nodes ()
@@ -302,9 +358,9 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
302358 continue
303359 # Start fusing search
304360 fusing_nodes = [] # nodes that are candidates for participating in fusing
305- patterns = copy .deepcopy (self . _fusing_patterns )
361+ patterns = copy .deepcopy (fusing_layer_patterns )
306362 next_nodes = [node ]
307- for i in range (max_layers_fusing ):
363+ for i in range (max_layer_patterns ):
308364 patterns = get_valid_fusing_patterns_for_node (patterns , next_nodes [0 ], i )
309365 if len (patterns ) == 0 : # Give up if no more fusion pattern
310366 break
@@ -314,7 +370,7 @@ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
314370 break
315371
316372 # New fusion
317- if is_valid_fusion (self . _fusing_patterns , fusing_nodes ):
373+ if is_valid_fusion (fusing_layer_patterns , fusing_nodes ):
318374 fused_op_id = FusingInfo .generate_fused_op_id (fusing_nodes )
319375 assert fused_op_id not in fusing_info , f"{ fused_op_id } is already in fusing info: { fusing_info } "
320376 fusing_info [fused_op_id ] = tuple (fusing_nodes )
@@ -371,3 +427,15 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -
371427 if counter == fusion_depth :
372428 return True
373429 return False
430+
431+
def _get_fusing_layer_patterns(fusing_patterns: List[Dict[Any, OpQuantizationConfig]]) -> List[List[Any]]:
    """
    Collect the layer pattern stored in each fusing pattern entry.

    Args:
        fusing_patterns: List of fusing pattern entries; each entry is queried with
            the FUSED_LAYER_PATTERN key.

    Returns:
        A list with one layer pattern per entry, in the same order as the input.
        An entry without the FUSED_LAYER_PATTERN key contributes None.
    """
    layer_patterns = [pattern_entry.get(FUSED_LAYER_PATTERN) for pattern_entry in fusing_patterns]
    return layer_patterns
0 commit comments