 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import importlib
+from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
 from typing import Callable, Dict, Optional, Sequence, Set, Tuple
@@ -58,7 +60,7 @@ class QuantDtype(IntEnum):
     use_8a8w = 4


-quant_config_dict = {
+QUANT_CONFIG_DICT = {
     # PTQ
     (QuantDtype.use_16a16w, False): (
         get_16a16w_qnn_ptq_config,
@@ -123,21 +125,66 @@ class QuantDtype(IntEnum):
 }


+@dataclass
+class ModuleQConfig:
+    quant_dtype: QuantDtype = QuantDtype.use_8a8w
+    is_qat: bool = False
+    is_conv_per_channel: bool = False
+    is_linear_per_channel: bool = False
+    act_observer: Optional[
+        torch.ao.quantization.observer.UniformQuantizationObserverBase
+    ] = None
+
+    def __post_init__(self):
+        if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
+            raise RuntimeError(
141+ f"the quant config, (quant_dtype: { self .quant_dtype } , is_qat: { self .is_qat } ) is not support"
+            )
+        quant_config_func, per_channel_quant_config_func, per_block_quant_config_func = QUANT_CONFIG_DICT[
+            (self.quant_dtype, self.is_qat)
+        ]
+        self.quant_config = (
+            quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else quant_config_func()
+        )
+        self.per_channel_quant_config = (
+            per_channel_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_channel_quant_config_func()
+        )
+        self.per_block_quant_config = (
+            per_block_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_block_quant_config_func()
+        )
+        self.use_per_channel_weight_quant_ops = set()
+        if self.is_conv_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.conv1d.default,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.conv_transpose2d.input,
+                }
+            )
+        if self.is_linear_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.linear.default,
+                }
+            )
+
+
 class QnnQuantizer(Quantizer):
     SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

     def __init__(self):
         super().__init__()
         self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

-        self.is_qat = False
-        self.quant_dtype = QuantDtype.use_8a8w
-        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
-        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
-        self.per_block_quant_config = get_ptq_per_block_quant_config()
+        self.default_quant_config = ModuleQConfig()
+        self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {}
         self.block_size_map = {}
-        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
-        self.use_per_block_weight_quant_ops: Set[OpOverload] = set()

         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
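
For orientation, here is a minimal usage sketch of the new `ModuleQConfig` dataclass introduced above. It is not part of the patch; the import path is assumed from the repository layout, and the flags simply exercise `__post_init__`.

```python
# Hedged sketch, not part of the diff: build a default 8a8w PTQ config with
# per-channel weight quantization enabled for conv and linear ops.
import torch
from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype

qconfig = ModuleQConfig(
    quant_dtype=QuantDtype.use_8a8w,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)

# __post_init__ resolved the (use_8a8w, is_qat=False) factories from
# QUANT_CONFIG_DICT and recorded which aten ops take per-channel weights.
assert torch.ops.aten.conv2d.default in qconfig.use_per_channel_weight_quant_ops
assert torch.ops.aten.linear.default in qconfig.use_per_channel_weight_quant_ops
```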
@@ -155,41 +202,52 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
         for annotation_func in self.custom_quant_annotations:
             annotation_func(gm)

-    def _get_quant_config(self, op: torch.fx.Node) -> Optional[QuantizationConfig]:
+    def _get_submodule(self, node: torch.fx.Node):
+        """
+        An example of nn_module_stack:
+        {
+            'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+            'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+        }
         """
-        Priority:
+
+        nn_module_stack = node.meta.get("nn_module_stack")
+        if nn_module_stack:
+            module_source_str, module_str = list(nn_module_stack.values())[-1][
+                -1
+            ].rsplit(".", 1)
+            module_source = importlib.import_module(module_source_str)
+            return getattr(module_source, module_str)
+        return None
+
+    def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
+        """
+        How to pick:
             1. is one of use_per_block_weight_quant_ops
-            2. is one of use_per_channel_weight_quant_ops
-            3. quant config
+            2. Use the submodule-specific config if one was registered.
+            3. Use the per-channel config if the op is in use_per_channel_weight_quant_ops.
+            4. Otherwise, fall back to the chosen config's normal quant config.
         """
-        target = op.target
-        if isinstance(target, str):
+        op = node.target
+        if isinstance(op, str):
             return

-        if target in self.use_per_block_weight_quant_ops:
-            if block_size := self.block_size_map.get(op.name):
-                self.per_block_quant_config.block_size = block_size
-                return self.per_block_quant_config
-
-        if target in self.use_per_channel_weight_quant_ops:
-            return self.per_channel_quant_config
+        if block_size := self.block_size_map.get(node.name):
+            config = self.default_quant_config.per_block_quant_config
+            config.block_size = block_size
+            return config

-        if target in self.quant_ops:
-            return self.quant_config
+        config = self.module_qconfig_dict.get(
+            self._get_submodule(node), self.default_quant_config
+        )

-        print(f"No quant config is implemented for op, {op}")
+        if op in config.use_per_channel_weight_quant_ops:
+            return config.per_channel_quant_config

-    def _update_per_block_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_block_weight_quant_ops.update(ops)
-        else:
-            self.use_per_block_weight_quant_ops.difference_update(ops)
+        if op in self.quant_ops:
+            return config.quant_config

-    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_channel_weight_quant_ops.update(ops)
-        else:
-            self.use_per_channel_weight_quant_ops.difference_update(ops)
+        print(f"No quant config is implemented for op, {op}")

     def add_custom_quant_annotations(
         self, custom_quant_annotations: Sequence[Callable]
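
To make the new lookup concrete, the sketch below replays what `_get_submodule` does with the last `nn_module_stack` entry from the docstring above. The stack dictionary is fabricated for illustration; in practice it comes from `node.meta["nn_module_stack"]`, and the resolved value is the submodule's class, which is what `module_qconfig_dict` is effectively keyed by.

```python
# Standalone sketch of the class resolution in _get_submodule; the stack
# contents are fabricated and mirror the docstring example in this patch.
import importlib

nn_module_stack = {
    "L__self__": ("", "executorch.backends.qualcomm.tests.models.SubModules"),
    "L__self___add": ("add", "executorch.backends.qualcomm.tests.models.Add"),
}

# Take the innermost (last) entry, split "package.module.ClassName", import the
# module, and fetch the class object used as the module_qconfig_dict key.
module_source_str, module_str = list(nn_module_stack.values())[-1][-1].rsplit(".", 1)
submodule_cls = getattr(importlib.import_module(module_source_str), module_str)
# module_qconfig_dict.get(submodule_cls, default_quant_config) then picks the
# ModuleQConfig used for this node.
```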
@@ -212,52 +270,32 @@ def annotate(self, model: GraphModule) -> GraphModule:
     def get_supported_ops(self) -> Set[OpOverload]:
         return self.SUPPORTED_OPS

-    def set_quant_config(
-        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
+    def set_default_quant_config(
+        self,
+        quant_dtype: QuantDtype,
+        is_qat=False,
+        is_conv_per_channel=False,
+        is_linear_per_channel=False,
+        act_observer=None,
     ) -> None:
-        self.quant_dtype = quant_dtype
-        self.is_qat = is_qat
-        if (quant_dtype, is_qat) not in quant_config_dict:
-            raise RuntimeError(
-                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
-            )
-
-        quant_config_fuc, per_channel_quant_config_fuc, per_block_quant_config_fuc = (
-            quant_config_dict[(quant_dtype, is_qat)]
-        )
-        self.quant_config = (
-            quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else quant_config_fuc()
+        self.default_quant_config = ModuleQConfig(
+            quant_dtype,
+            is_qat,
+            is_conv_per_channel,
+            is_linear_per_channel,
+            act_observer,
         )
-        self.per_channel_quant_config = (
-            per_channel_quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else per_channel_quant_config_fuc()
-        )
-        if per_block_quant_config_fuc is not None:
-            self.per_block_quant_config = (
-                per_block_quant_config_fuc(act_observer=act_observer)
-                if act_observer
-                else per_block_quant_config_fuc()
-            )

     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
         self.block_size_map = block_size_map

-    def set_per_block_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv2d.default}
-        self._update_per_block_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
-        self._update_per_channel_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_linear_quant(self, enable: bool) -> None:
-        linear_ops = {
-            torch.ops.aten.linear.default,
-        }
-        self._update_per_channel_weight_quant_ops(linear_ops, enable)
+    def set_submodule_quant_config(
+        self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig
+    ) -> None:
+        """
+        Set the quant config for a specific submodule class
297+ """
298+ self .module_qconfig_dict [submodule ] = module_qconfig
261299
262300 def transform_for_annotation (self , model : GraphModule ) -> GraphModule :
263301 return QnnPassManager ().transform_for_annotation_pipeline (model )
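
Finally, a hedged migration sketch for callers of the removed `set_per_channel_conv_quant` / `set_per_channel_linear_quant` setters. The submodule class and the import path are placeholders and assumptions, not part of this patch.

```python
# Hypothetical caller-side usage of the new API; MyBlock is a placeholder class.
import torch
from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
)


class MyBlock(torch.nn.Module):
    """Placeholder submodule class for illustration only."""


quantizer = QnnQuantizer()

# Old: quantizer.set_per_channel_conv_quant(True)
#      quantizer.set_per_channel_linear_quant(True)
# New: the per-channel switches travel with the default ModuleQConfig.
quantizer.set_default_quant_config(
    QuantDtype.use_8a8w,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)

# New capability: a dedicated config for one submodule class (16a16w PTQ here).
quantizer.set_submodule_quant_config(
    MyBlock, ModuleQConfig(QuantDtype.use_16a16w, is_linear_per_channel=True)
)
```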