33#
44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
6- import importlib
76from dataclasses import dataclass
87from enum import IntEnum , unique
98from functools import partial
10- from typing import Callable , Dict , Optional , Sequence , Set , Tuple
9+ from typing import Callable , Dict , List , Optional , Sequence , Set , Tuple
1110
1211import torch
1312from executorch .backends .qualcomm ._passes .qnn_pass_manager import QnnPassManager
@@ -140,9 +139,11 @@ def __post_init__(self):
140139 raise RuntimeError (
141140 f"the quant config, (quant_dtype: { self .quant_dtype } , is_qat: { self .is_qat } ) is not support"
142141 )
143- quant_config_func , per_channel_quant_config_func , per_block_quant_config_func = QUANT_CONFIG_DICT [
144- (self .quant_dtype , self .is_qat )
145- ]
142+ (
143+ quant_config_func ,
144+ per_channel_quant_config_func ,
145+ per_block_quant_config_func ,
146+ ) = QUANT_CONFIG_DICT [(self .quant_dtype , self .is_qat )]
146147 self .quant_config = (
147148 quant_config_func (act_observer = self .act_observer )
148149 if self .act_observer
@@ -184,7 +185,9 @@ def __init__(self):
184185 self .quant_ops : Set [OpOverload ] = self .SUPPORTED_OPS .copy ()
185186
186187 self .default_quant_config = ModuleQConfig ()
187- self .module_qconfig_dict : Dict [torch .nn .Module , ModuleQConfig ] = {}
188+ self .submodule_qconfig_list : List [
189+ Tuple [Callable [[torch .fx .Node ], bool ], ModuleQConfig ]
190+ ] = []
188191 self .block_size_map = {}
189192
190193 self .custom_quant_annotations : Sequence [Callable ] = []
@@ -203,44 +206,30 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
203206 for annotation_func in self .custom_quant_annotations :
204207 annotation_func (gm )
205208
206- def _get_submodule (self , node : torch .fx .Node ):
207- """
208- An example of nn_module_stack
209- {
210- 'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
211- 'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
212- }
213- """
214-
215- nn_module_stack = node .meta .get ("nn_module_stack" )
216- if nn_module_stack :
217- module_source_str , module_str = list (nn_module_stack .values ())[- 1 ][
218- - 1
219- ].rsplit ("." , 1 )
220- module_source = importlib .import_module (module_source_str )
221- return getattr (module_source , module_str )
222- return None
209+ def _get_submodule_qconfig (self , node : torch .fx .Node ):
210+ for func , qconfig in self .submodule_qconfig_list :
211+ if func (node ):
212+ return qconfig
213+ return self .default_quant_config
223214
224215 def _get_quant_config (self , node : torch .fx .Node ) -> Optional [QuantizationConfig ]:
225216 """
226217 How to pick:
227- 1. is one of use_per_block_weight_quant_ops
228- 2. Choose specific submodule config if given.
218+ 1. is one of per_block_quant_config
219+ 2. Pick specific submodule config if given.
229220 3. Pick one if op belongs to use_per_channel_weight_quant_ops
230- 4. If not 2 , pick normal quant config
221+ 4. If not 3 , pick normal quant config
231222 """
232223 op = node .target
233224 if isinstance (op , str ):
234225 return
235226
236- if block_size := self .block_size_map .get (op .name ):
227+ if block_size := self .block_size_map .get (node .name ):
237228 config = self .default_quant_config .per_block_quant_config
238229 config .block_size = block_size
239230 return config
240231
241- config = self .module_qconfig_dict .get (
242- self ._get_submodule (node ), self .default_quant_config
243- )
232+ config = self ._get_submodule_qconfig (node )
244233
245234 if op in config .use_per_channel_weight_quant_ops :
246235 return config .per_channel_quant_config
@@ -290,16 +279,55 @@ def set_default_quant_config(
def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
    """
    Replace the table of per-block quantization block sizes.

    The mapping is stored by reference (not copied) on
    ``self.block_size_map``, overwriting whatever was set before.

    Args:
        block_size_map: Mapping from a string key to a block-size tuple.
    """
    self.block_size_map = block_size_map
292281
def set_submodule_qconfig_list(
    self, submodule_qconfig_list: List[Tuple[Callable, ModuleQConfig]]
) -> None:
    """
    Set specific quant config from a callback function.
    If a node fits more than one callback, only apply the first one.

    The list is stored by reference on ``self.submodule_qconfig_list`` and
    replaces any previously registered entries.

    Args:
        submodule_qconfig_list: ``(callback, qconfig)`` pairs, in priority
            order. Each callback receives a node and returns a truthy value
            when its paired qconfig should be used for that node.
    """
    self.submodule_qconfig_list = submodule_qconfig_list
300290
def transform_for_annotation(self, model: GraphModule) -> GraphModule:
    """
    Run the annotation-preparation pass pipeline over ``model``.

    Delegates to ``QnnPassManager().transform_for_annotation_pipeline`` on a
    freshly constructed pass manager and returns its result unchanged.

    Args:
        model: The graph module about to be annotated.

    Returns:
        The graph module produced by the pass pipeline.
    """
    return QnnPassManager().transform_for_annotation_pipeline(model)
303293
def validate(self, model: GraphModule) -> None:
    """
    Validation hook; intentionally a no-op.

    Args:
        model: The graph module to validate (not inspected).
    """
    pass
296+
297+
def get_submodule_type_predicate(module_type_str: str):
    """
    Build a node predicate that matches on submodule type name.

    The returned callable takes a node and returns True when
    ``module_type_str`` is a substring of the type name of any entry in the
    node's ``nn_module_stack`` metadata. It returns False otherwise,
    including when the metadata is missing or empty. Because every entry of
    the stack is checked, nodes nested inside a matching module also match.

    An example of nn_module_stack
    {
        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
    }

    Args:
        module_type_str: Text to look for inside the module type names.

    Returns:
        A ``predicate(node) -> bool`` callable.
    """

    def predicate(node) -> bool:
        nn_module_stack = node.meta.get("nn_module_stack")
        if not nn_module_stack:
            return False
        # Each value is a (path, type_name) pair; only the type name is matched.
        return any(
            module_type_str in type_name for _, type_name in nn_module_stack.values()
        )

    return predicate
315+
316+
def get_submodule_name_predicate(module_name_str: str):
    """
    Build a node predicate that matches on submodule name.

    The returned callable takes a node and returns True when
    ``module_name_str`` is a substring of any key of the node's
    ``nn_module_stack`` metadata. It returns False otherwise, including when
    the metadata is missing or empty. The match is against the dictionary
    keys (e.g. ``'L__self___add'``), not the path stored in the values.

    An example of nn_module_stack
    {
        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
    }

    Args:
        module_name_str: Text to look for inside the module stack keys.

    Returns:
        A ``predicate(node) -> bool`` callable.
    """

    def predicate(node) -> bool:
        nn_module_stack = node.meta.get("nn_module_stack")
        if not nn_module_stack:
            return False
        # Iterating the mapping yields its keys, which carry the module names.
        return any(module_name_str in name for name in nn_module_stack)

    return predicate
0 commit comments