 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import importlib
 from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
-from typing import Callable, Dict, Optional, Sequence, Set, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple

 import torch
 from executorch.backends.qualcomm._passes import (
@@ -153,9 +152,11 @@ def __post_init__(self):
             raise RuntimeError(
                 f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not support"
             )
-        quant_config_func, per_channel_quant_config_func, per_block_quant_config_func = QUANT_CONFIG_DICT[
-            (self.quant_dtype, self.is_qat)
-        ]
+        (
+            quant_config_func,
+            per_channel_quant_config_func,
+            per_block_quant_config_func,
+        ) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)]
         self.quant_config = (
             quant_config_func(act_observer=self.act_observer)
             if self.act_observer
@@ -197,7 +198,9 @@ def __init__(self):
         self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

         self.default_quant_config = ModuleQConfig()
-        self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {}
+        self.submodule_qconfig_list: List[
+            Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig]
+        ] = []
         self.block_size_map = {}

         self.custom_quant_annotations: Sequence[Callable] = []
@@ -216,44 +219,30 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
         for annotation_func in self.custom_quant_annotations:
             annotation_func(gm)

-    def _get_submodule(self, node: torch.fx.Node):
-        """
-        An example of nn_module_stack
-        {
-            'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
-            'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
-        }
-        """
-
-        nn_module_stack = node.meta.get("nn_module_stack")
-        if nn_module_stack:
-            module_source_str, module_str = list(nn_module_stack.values())[-1][
-                -1
-            ].rsplit(".", 1)
-            module_source = importlib.import_module(module_source_str)
-            return getattr(module_source, module_str)
-        return None
+    def _get_submodule_qconfig(self, node: torch.fx.Node):
+        for func, qconfig in self.submodule_qconfig_list:
+            if func(node):
+                return qconfig
+        return self.default_quant_config

     def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
         """
         How to pick:
-        1. is one of use_per_block_weight_quant_ops
-        2. Choose specific submodule config if given.
+        1. is one of per_block_quant_config
+        2. Pick specific submodule config if given.
         3. Pick one if op belongs to use_per_channel_weight_quant_ops
-        4. If not 2, pick normal quant config
+        4. If not 3, pick normal quant config
         """
         op = node.target
         if isinstance(op, str):
             return

-        if block_size := self.block_size_map.get(op.name):
+        if block_size := self.block_size_map.get(node.name):
             config = self.default_quant_config.per_block_quant_config
             config.block_size = block_size
             return config

-        config = self.module_qconfig_dict.get(
-            self._get_submodule(node), self.default_quant_config
-        )
+        config = self._get_submodule_qconfig(node)

         if op in config.use_per_channel_weight_quant_ops:
             return config.per_channel_quant_config
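
Note that this hunk also switches the block-size lookup key from op.name to node.name, so entries in block_size_map must now be keyed by FX node names rather than operator names. A minimal usage sketch under that assumption, where the quantizer class is assumed to be QnnQuantizer and the node name and block size below are purely illustrative:

    quantizer = QnnQuantizer()
    # Per-block quantization: keys are FX node names in the exported graph,
    # values are the block_size tuples consumed by per_block_quant_config.
    # "linear_default" and (1, 64) are hypothetical placeholders.
    quantizer.set_block_size_map({"linear_default": (1, 64)})
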
@@ -303,13 +292,14 @@ def set_default_quant_config(
     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
         self.block_size_map = block_size_map

-    def set_submodule_quant_config(
-        self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig
+    def set_submodule_qconfig_list(
+        self, submodule_qconfig_list: List[Tuple[Callable, ModuleQConfig]]
     ) -> None:
         """
-        Set the quant config specific for a submodule
+        Set specific quant config from a callback function.
+        If a node fits more than one callback, only apply the first one.
         """
-        self.module_qconfig_dict[submodule] = module_qconfig
+        self.submodule_qconfig_list = submodule_qconfig_list

     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = ReduceDynamicRange()(model).graph_module
@@ -326,3 +316,41 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:

     def validate(self, model: GraphModule) -> None:
         pass
+
+
+def get_submodule_type_predicate(module_type_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for _, type_name in nn_module_stack.values():
+                if module_type_str in type_name:
+                    return True
+        return False
+
+    return predicate
+
+
+def get_submodule_name_predicate(module_name_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for name in nn_module_stack.keys():
+                if module_name_str in name:
+                    return True
+        return False
+
+    return predicate
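
The new module-level helpers pair with set_submodule_qconfig_list: callers register (predicate, ModuleQConfig) tuples, and _get_submodule_qconfig returns the config attached to the first predicate that matches a node, falling back to the default otherwise. A hedged usage sketch, assuming the quantizer class in this file is QnnQuantizer and reusing the "Add" type from the docstring example as the match string:

    quantizer = QnnQuantizer()
    # Nodes whose nn_module_stack type name contains "Add" get a dedicated
    # ModuleQConfig (construct it with whatever fields your setup requires);
    # every other node keeps the default quant config.
    quantizer.set_submodule_qconfig_list(
        [
            (get_submodule_type_predicate("Add"), ModuleQConfig()),
        ]
    )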