33#
44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
6+ from dataclasses import dataclass
67from enum import IntEnum , unique
78from functools import partial
8- from typing import Callable , Dict , Optional , Sequence , Set , Tuple
9+ from typing import Callable , Dict , List , Optional , Sequence , Set , Tuple
910
1011import torch
1112from executorch .backends .qualcomm ._passes .qnn_pass_manager import QnnPassManager
@@ -58,7 +59,7 @@ class QuantDtype(IntEnum):
5859 use_8a8w = 4
5960
6061
61- quant_config_dict = {
62+ QUANT_CONFIG_DICT = {
6263 # PTQ
6364 (QuantDtype .use_16a16w , False ): (
6465 get_16a16w_qnn_ptq_config ,
@@ -123,21 +124,71 @@ class QuantDtype(IntEnum):
123124}
124125
125126
127+ @dataclass
128+ class ModuleQConfig :
129+ quant_dtype : QuantDtype = QuantDtype .use_8a8w
130+ is_qat : bool = False
131+ is_conv_per_channel : bool = False
132+ is_linear_per_channel : bool = False
133+ act_observer : Optional [
134+ torch .ao .quantization .observer .UniformQuantizationObserverBase
135+ ] = None
136+
137+ def __post_init__ (self ):
138+ if (self .quant_dtype , self .is_qat ) not in QUANT_CONFIG_DICT :
139+ raise RuntimeError (
140+ f"the quant config, (quant_dtype: { self .quant_dtype } , is_qat: { self .is_qat } ) is not support"
141+ )
142+ (
143+ quant_config_func ,
144+ per_channel_quant_config_func ,
145+ per_block_quant_config_func ,
146+ ) = QUANT_CONFIG_DICT [(self .quant_dtype , self .is_qat )]
147+ self .quant_config = (
148+ quant_config_func (act_observer = self .act_observer )
149+ if self .act_observer
150+ else quant_config_func ()
151+ )
152+ self .per_channel_quant_config = (
153+ per_channel_quant_config_func (act_observer = self .act_observer )
154+ if self .act_observer
155+ else per_channel_quant_config_func ()
156+ )
157+ self .use_per_channel_weight_quant_ops = set ()
158+ if self .is_conv_per_channel :
159+ self .use_per_channel_weight_quant_ops .update (
160+ {
161+ torch .ops .aten .conv1d .default ,
162+ torch .ops .aten .conv2d .default ,
163+ torch .ops .aten .conv_transpose2d .input ,
164+ }
165+ )
166+ if self .is_linear_per_channel :
167+ self .use_per_channel_weight_quant_ops .update (
168+ {
169+ torch .ops .aten .linear .default ,
170+ }
171+ )
172+ if per_block_quant_config_func :
173+ self .per_block_quant_config = (
174+ per_block_quant_config_func (act_observer = self .act_observer )
175+ if self .act_observer
176+ else per_block_quant_config_func ()
177+ )
178+
179+
126180class QnnQuantizer (Quantizer ):
127181 SUPPORTED_OPS : Set = set (OP_ANNOTATOR .keys ())
128182
129183 def __init__ (self ):
130184 super ().__init__ ()
131185 self .quant_ops : Set [OpOverload ] = self .SUPPORTED_OPS .copy ()
132186
133- self .is_qat = False
134- self .quant_dtype = QuantDtype .use_8a8w
135- self .quant_config : QuantizationConfig = get_8a8w_qnn_ptq_config ()
136- self .per_channel_quant_config = get_ptq_per_channel_quant_config ()
137- self .per_block_quant_config = get_ptq_per_block_quant_config ()
187+ self .default_quant_config = ModuleQConfig ()
188+ self .submodule_qconfig_list : List [
189+ Tuple [Callable [[torch .fx .Node ], bool ], ModuleQConfig ]
190+ ] = []
138191 self .block_size_map = {}
139- self .use_per_channel_weight_quant_ops : Set [OpOverload ] = set ()
140- self .use_per_block_weight_quant_ops : Set [OpOverload ] = set ()
141192
142193 self .custom_quant_annotations : Sequence [Callable ] = []
143194 self .discard_nodes : Set [str ] = set ()
@@ -155,41 +206,38 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
155206 for annotation_func in self .custom_quant_annotations :
156207 annotation_func (gm )
157208
158- def _get_quant_config (self , op : torch .fx .Node ) -> Optional [QuantizationConfig ]:
209+ def _get_submodule_qconfig (self , node : torch .fx .Node ):
210+ for func , qconfig in self .submodule_qconfig_list :
211+ if func (node ):
212+ return qconfig
213+ return self .default_quant_config
214+
215+ def _get_quant_config (self , node : torch .fx .Node ) -> Optional [QuantizationConfig ]:
159216 """
160- Priority:
161- 1. is one of use_per_block_weight_quant_ops
162- 2. is one of use_per_channel_weight_quant_ops
163- 3. quant config
217+ How to pick:
218+ 1. is one of per_block_quant_config
219+ 2. Pick specific submodule config if given.
220+ 3. Pick one if op belongs to use_per_channel_weight_quant_ops
221+ 4. If not 3, pick normal quant config
164222 """
165- target = op .target
166- if isinstance (target , str ):
223+ op = node .target
224+ if isinstance (op , str ):
167225 return
168226
169- if target in self .use_per_block_weight_quant_ops :
170- if block_size : = self .block_size_map . get ( op . name ):
171- self . per_block_quant_config .block_size = block_size
172- return self . per_block_quant_config
227+ if block_size := self .block_size_map . get ( node . name ) :
228+ config = self .default_quant_config . per_block_quant_config
229+ config .block_size = block_size
230+ return config
173231
174- if target in self .use_per_channel_weight_quant_ops :
175- return self .per_channel_quant_config
232+ config = self ._get_submodule_qconfig (node )
176233
177- if target in self . quant_ops :
178- return self . quant_config
234+ if op in config . use_per_channel_weight_quant_ops :
235+ return config . per_channel_quant_config
179236
180- print (f"No quant config is implemented for op, { op } " )
181-
182- def _update_per_block_weight_quant_ops (self , ops : Set [OpOverload ], enable : bool ):
183- if enable :
184- self .use_per_block_weight_quant_ops .update (ops )
185- else :
186- self .use_per_block_weight_quant_ops .difference_update (ops )
237+ if op in self .quant_ops :
238+ return config .quant_config
187239
188- def _update_per_channel_weight_quant_ops (self , ops : Set [OpOverload ], enable : bool ):
189- if enable :
190- self .use_per_channel_weight_quant_ops .update (ops )
191- else :
192- self .use_per_channel_weight_quant_ops .difference_update (ops )
240+ print (f"No quant config is implemented for op, { op } " )
193241
194242 def add_custom_quant_annotations (
195243 self , custom_quant_annotations : Sequence [Callable ]
@@ -212,55 +260,74 @@ def annotate(self, model: GraphModule) -> GraphModule:
212260 def get_supported_ops (self ) -> Set [OpOverload ]:
213261 return self .SUPPORTED_OPS
214262
215- def set_quant_config (
216- self , quant_dtype : QuantDtype , is_qat = False , act_observer = None
263+ def set_default_quant_config (
264+ self ,
265+ quant_dtype : QuantDtype ,
266+ is_qat = False ,
267+ is_conv_per_channel = False ,
268+ is_linear_per_channel = False ,
269+ act_observer = None ,
217270 ) -> None :
218- self .quant_dtype = quant_dtype
219- self .is_qat = is_qat
220- if (quant_dtype , is_qat ) not in quant_config_dict :
221- raise RuntimeError (
222- f"the quant config, (quant_dtype: { quant_dtype } , is_qat: { is_qat } ) is not support"
223- )
224-
225- quant_config_fuc , per_channel_quant_config_fuc , per_block_quant_config_fuc = (
226- quant_config_dict [(quant_dtype , is_qat )]
227- )
228- self .quant_config = (
229- quant_config_fuc (act_observer = act_observer )
230- if act_observer
231- else quant_config_fuc ()
271+ self .default_quant_config = ModuleQConfig (
272+ quant_dtype ,
273+ is_qat ,
274+ is_conv_per_channel ,
275+ is_linear_per_channel ,
276+ act_observer ,
232277 )
233- self .per_channel_quant_config = (
234- per_channel_quant_config_fuc (act_observer = act_observer )
235- if act_observer
236- else per_channel_quant_config_fuc ()
237- )
238- if per_block_quant_config_fuc is not None :
239- self .per_block_quant_config = (
240- per_block_quant_config_fuc (act_observer = act_observer )
241- if act_observer
242- else per_block_quant_config_fuc ()
243- )
244278
245279 def set_block_size_map (self , block_size_map : Dict [str , Tuple ]) -> None :
246280 self .block_size_map = block_size_map
247281
248- def set_per_block_conv_quant (self , enable : bool ) -> None :
249- conv_ops = {torch .ops .aten .conv2d .default }
250- self ._update_per_block_weight_quant_ops (conv_ops , enable )
251-
252- def set_per_channel_conv_quant (self , enable : bool ) -> None :
253- conv_ops = {torch .ops .aten .conv1d .default , torch .ops .aten .conv2d .default }
254- self ._update_per_channel_weight_quant_ops (conv_ops , enable )
255-
256- def set_per_channel_linear_quant (self , enable : bool ) -> None :
257- linear_ops = {
258- torch .ops .aten .linear .default ,
259- }
260- self ._update_per_channel_weight_quant_ops (linear_ops , enable )
282+ def set_submodule_qconfig_list (
283+ self , submodule_qconfig_list : List [Tuple [Callable , ModuleQConfig ]]
284+ ) -> None :
285+ """
286+ Set specific quant config from a callback function.
287+ If a node fits more than one callback, only apply the first one.
288+ """
289+ self .submodule_qconfig_list = submodule_qconfig_list
261290
262291 def transform_for_annotation (self , model : GraphModule ) -> GraphModule :
263292 return QnnPassManager ().transform_for_annotation_pipeline (model )
264293
265294 def validate (self , model : GraphModule ) -> None :
266295 pass
296+
297+
298+ def get_submodule_type_predicate (module_type_str ):
299+ """
300+ An example of nn_module_stack
301+ {
302+ 'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
303+ 'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
304+ }
305+ """
306+
307+ def predicate (node ):
308+ if nn_module_stack := node .meta .get ("nn_module_stack" ):
309+ for _ , type_name in nn_module_stack .values ():
310+ if module_type_str in type_name :
311+ return True
312+ return False
313+
314+ return predicate
315+
316+
317+ def get_submodule_name_predicate (module_name_str ):
318+ """
319+ An example of nn_module_stack
320+ {
321+ 'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
322+ 'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
323+ }
324+ """
325+
326+ def predicate (node ):
327+ if nn_module_stack := node .meta .get ("nn_module_stack" ):
328+ for name in nn_module_stack .keys ():
329+ if module_name_str in name :
330+ return True
331+ return False
332+
333+ return predicate
0 commit comments