 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import importlib
+from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
 from typing import Callable, Dict, Optional, Sequence, Set, Tuple
@@ -71,7 +73,7 @@ class QuantDtype(IntEnum):
    use_8a8w = 4


-quant_config_dict = {
+QUANT_CONFIG_DICT = {
    # PTQ
    (QuantDtype.use_16a16w, False): (
        get_16a16w_qnn_ptq_config,
@@ -136,21 +138,66 @@ class QuantDtype(IntEnum):
}


+@dataclass
+class ModuleQConfig:
+    quant_dtype: QuantDtype = QuantDtype.use_8a8w
+    is_qat: bool = False
+    is_conv_per_channel: bool = False
+    is_linear_per_channel: bool = False
+    act_observer: Optional[
+        torch.ao.quantization.observer.UniformQuantizationObserverBase
+    ] = None
+
+    def __post_init__(self):
+        if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
+            raise RuntimeError(
+                f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not supported"
+            )
+        quant_config_func, per_channel_quant_config_func, per_block_quant_config_func = QUANT_CONFIG_DICT[
+            (self.quant_dtype, self.is_qat)
+        ]
+        self.quant_config = (
+            quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else quant_config_func()
+        )
+        self.per_channel_quant_config = (
+            per_channel_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_channel_quant_config_func()
+        )
+        self.per_block_quant_config = (
+            per_block_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_block_quant_config_func()
+        )
+        self.use_per_channel_weight_quant_ops = set()
+        if self.is_conv_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.conv1d.default,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.conv_transpose2d.input,
+                }
+            )
+        if self.is_linear_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.linear.default,
+                }
+            )
+
+
class QnnQuantizer(Quantizer):
    SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

    def __init__(self):
        super().__init__()
        self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

-        self.is_qat = False
-        self.quant_dtype = QuantDtype.use_8a8w
-        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
-        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
-        self.per_block_quant_config = get_ptq_per_block_quant_config()
+        self.default_quant_config = ModuleQConfig()
+        self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {}
        self.block_size_map = {}
-        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
-        self.use_per_block_weight_quant_ops: Set[OpOverload] = set()

        self.custom_quant_annotations: Sequence[Callable] = []
        self.discard_nodes: Set[str] = set()
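For reference, a minimal usage sketch (not part of the diff) of the ModuleQConfig added above; MovingAverageMinMaxObserver is only an illustrative act_observer, and the (quant_dtype, is_qat) pair is assumed to be one of the keys in QUANT_CONFIG_DICT:

from torch.ao.quantization.observer import MovingAverageMinMaxObserver

# Hypothetical construction; __post_init__ resolves the derived configs eagerly.
qconfig = ModuleQConfig(
    quant_dtype=QuantDtype.use_16a16w,
    is_qat=False,
    is_conv_per_channel=True,
    act_observer=MovingAverageMinMaxObserver,
)
# After construction the derived fields are populated:
#   qconfig.quant_config, qconfig.per_channel_quant_config,
#   qconfig.per_block_quant_config, and
#   qconfig.use_per_channel_weight_quant_ops == {conv1d, conv2d, conv_transpose2d}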
@@ -168,41 +215,52 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
        for annotation_func in self.custom_quant_annotations:
            annotation_func(gm)

-    def _get_quant_config(self, op: torch.fx.Node) -> Optional[QuantizationConfig]:
+    def _get_submodule(self, node: torch.fx.Node):
+        """
+        An example of nn_module_stack
+        {
+            'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+            'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+        }
        """
-        Priority:
+
+        nn_module_stack = node.meta.get("nn_module_stack")
+        if nn_module_stack:
+            module_source_str, module_str = list(nn_module_stack.values())[-1][
+                -1
+            ].rsplit(".", 1)
+            module_source = importlib.import_module(module_source_str)
+            return getattr(module_source, module_str)
+        return None
+
+    def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
+        """
+        How to pick:
        1. is one of use_per_block_weight_quant_ops
-        2. is one of use_per_channel_weight_quant_ops
-        3. quant config
+        2. Choose the submodule-specific config if one is given.
+        3. Pick the per-channel config if the op belongs to use_per_channel_weight_quant_ops.
+        4. Otherwise, pick the normal quant config of the config chosen in step 2.
        """
-        target = op.target
-        if isinstance(target, str):
+        op = node.target
+        if isinstance(op, str):
            return

-        if target in self.use_per_block_weight_quant_ops:
-            if block_size := self.block_size_map.get(op.name):
-                self.per_block_quant_config.block_size = block_size
-                return self.per_block_quant_config
-
-        if target in self.use_per_channel_weight_quant_ops:
-            return self.per_channel_quant_config
+        if block_size := self.block_size_map.get(node.name):
+            config = self.default_quant_config.per_block_quant_config
+            config.block_size = block_size
+            return config

-        if target in self.quant_ops:
-            return self.quant_config
+        config = self.module_qconfig_dict.get(
+            self._get_submodule(node), self.default_quant_config
+        )

-        print(f"No quant config is implemented for op, {op}")
+        if op in config.use_per_channel_weight_quant_ops:
+            return config.per_channel_quant_config

-    def _update_per_block_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_block_weight_quant_ops.update(ops)
-        else:
-            self.use_per_block_weight_quant_ops.difference_update(ops)
+        if op in self.quant_ops:
+            return config.quant_config

-    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_channel_weight_quant_ops.update(ops)
-        else:
-            self.use_per_channel_weight_quant_ops.difference_update(ops)
+        print(f"No quant config is implemented for op, {op}")

    def add_custom_quant_annotations(
        self, custom_quant_annotations: Sequence[Callable]
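To make the new lookup path concrete, a rough walk-through based on the nn_module_stack example from the docstring above (illustrative, not a runnable test):

# The last nn_module_stack entry names the innermost module that produced the node.
nn_module_stack = {
    "L__self__": ("", "executorch.backends.qualcomm.tests.models.SubModules"),
    "L__self___add": ("add", "executorch.backends.qualcomm.tests.models.Add"),
}
module_source_str, module_str = list(nn_module_stack.values())[-1][-1].rsplit(".", 1)
# module_source_str == "executorch.backends.qualcomm.tests.models", module_str == "Add"
# importlib.import_module plus getattr then yields the Add class, which is used as the
# key into self.module_qconfig_dict; a miss falls back to self.default_quant_config.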
@@ -225,52 +283,32 @@ def annotate(self, model: GraphModule) -> GraphModule:
    def get_supported_ops(self) -> Set[OpOverload]:
        return self.SUPPORTED_OPS

-    def set_quant_config(
-        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
+    def set_default_quant_config(
+        self,
+        quant_dtype: QuantDtype,
+        is_qat=False,
+        is_conv_per_channel=False,
+        is_linear_per_channel=False,
+        act_observer=None,
    ) -> None:
-        self.quant_dtype = quant_dtype
-        self.is_qat = is_qat
-        if (quant_dtype, is_qat) not in quant_config_dict:
-            raise RuntimeError(
-                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
-            )
-
-        quant_config_fuc, per_channel_quant_config_fuc, per_block_quant_config_fuc = (
-            quant_config_dict[(quant_dtype, is_qat)]
-        )
-        self.quant_config = (
-            quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else quant_config_fuc()
+        self.default_quant_config = ModuleQConfig(
+            quant_dtype,
+            is_qat,
+            is_conv_per_channel,
+            is_linear_per_channel,
+            act_observer,
        )
-        self.per_channel_quant_config = (
-            per_channel_quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else per_channel_quant_config_fuc()
-        )
-        if per_block_quant_config_fuc is not None:
-            self.per_block_quant_config = (
-                per_block_quant_config_fuc(act_observer=act_observer)
-                if act_observer
-                else per_block_quant_config_fuc()
-            )

    def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
        self.block_size_map = block_size_map

-    def set_per_block_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv2d.default}
-        self._update_per_block_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
-        self._update_per_channel_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_linear_quant(self, enable: bool) -> None:
-        linear_ops = {
-            torch.ops.aten.linear.default,
-        }
-        self._update_per_channel_weight_quant_ops(linear_ops, enable)
+    def set_submodule_quant_config(
+        self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig
+    ) -> None:
+        """
+        Set the quant config specific to a submodule.
+        """
+        self.module_qconfig_dict[submodule] = module_qconfig

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        model = ReduceDynamicRange()(model).graph_module
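Taken together, a hedged sketch of how a caller might migrate from the removed per-op setters to the new API; MyDecoderBlock is a hypothetical nn.Module subclass used only for illustration:

quantizer = QnnQuantizer()

# Previously: quantizer.set_quant_config(...) plus set_per_channel_conv_quant(True), etc.
# Now a single default config covers the whole graph...
quantizer.set_default_quant_config(
    QuantDtype.use_8a8w,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)

# ...and per-submodule overrides are keyed by the submodule class, matching what
# _get_submodule resolves from nn_module_stack.
quantizer.set_submodule_quant_config(
    MyDecoderBlock,
    ModuleQConfig(QuantDtype.use_16a16w, is_linear_per_channel=True),
)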