#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6+ import importlib
7+ from dataclasses import dataclass
68from enum import IntEnum , unique
79from functools import partial
8- from typing import Callable , Optional , Sequence , Set
10+ from typing import Callable , Dict , Optional , Sequence , Set
911
1012import torch
1113from executorch .backends .qualcomm ._passes import (
@@ -66,7 +68,7 @@ class QuantDtype(IntEnum):
6668 use_8a8w = 3
6769
6870
69- quant_config_dict = {
71+ QUANT_CONFIG_DICT = {
7072 # PTQ
7173 (QuantDtype .use_16a16w , False ): (
7274 get_16a16w_qnn_ptq_config ,
@@ -112,18 +114,60 @@ class QuantDtype(IntEnum):
112114}
113115
114116
117+ @dataclass
118+ class ModuleQConfig :
119+ quant_dtype : QuantDtype = QuantDtype .use_8a8w
120+ is_qat : bool = False
121+ is_conv_per_channel : bool = False
122+ is_linear_per_channel : bool = False
123+ act_observer : Optional [
124+ torch .ao .quantization .observer .UniformQuantizationObserverBase
125+ ] = None
126+
127+ def __post_init__ (self ):
128+ if (self .quant_dtype , self .is_qat ) not in QUANT_CONFIG_DICT :
129+ raise RuntimeError (
130+ f"the quant config, (quant_dtype: { self .quant_dtype } , is_qat: { self .is_qat } ) is not support"
131+ )
132+ quant_config_func , per_channel_quant_config_func = QUANT_CONFIG_DICT [
133+ (self .quant_dtype , self .is_qat )
134+ ]
135+ self .quant_config = (
136+ quant_config_func (act_observer = self .act_observer )
137+ if self .act_observer
138+ else quant_config_func ()
139+ )
140+ self .per_channel_quant_config = (
141+ per_channel_quant_config_func (act_observer = self .act_observer )
142+ if self .act_observer
143+ else per_channel_quant_config_func ()
144+ )
145+ self .use_per_channel_weight_quant_ops = set ()
146+ if self .is_conv_per_channel :
147+ self .use_per_channel_weight_quant_ops .update (
148+ {
149+ torch .ops .aten .conv1d .default ,
150+ torch .ops .aten .conv2d .default ,
151+ torch .ops .aten .conv_transpose2d .input ,
152+ }
153+ )
154+ if self .is_linear_per_channel :
155+ self .use_per_channel_weight_quant_ops .update (
156+ {
157+ torch .ops .aten .linear .default ,
158+ }
159+ )
160+
161+
115162class QnnQuantizer (Quantizer ):
116163 SUPPORTED_OPS : Set = set (OP_ANNOTATOR .keys ())
117164
118165 def __init__ (self ):
119166 super ().__init__ ()
120167 self .quant_ops : Set [OpOverload ] = self .SUPPORTED_OPS .copy ()
121168
122- self .is_qat = False
123- self .quant_dtype = QuantDtype .use_8a8w
124- self .quant_config : QuantizationConfig = get_8a8w_qnn_ptq_config ()
125- self .per_channel_quant_config = get_ptq_per_channel_quant_config ()
126- self .use_per_channel_weight_quant_ops : Set [OpOverload ] = set ()
169+ self .default_quant_config = ModuleQConfig ()
170+ self .module_qconfig_dict : Dict [torch .nn .Module , ModuleQConfig ] = {}
127171
128172 self .custom_quant_annotations : Sequence [Callable ] = []
129173 self .discard_nodes : Set [str ] = set ()
@@ -133,37 +177,55 @@ def _annotate(self, gm: GraphModule) -> None:
133177 if node .name in self .discard_nodes :
134178 continue
135179
136- quant_config = self ._get_quant_config (node . target )
180+ quant_config = self ._get_quant_config (node )
137181 if quant_config :
138182 OP_ANNOTATOR [node .target ](node , quant_config )
139183
140184 def _annotate_custom_annotation (self , gm : GraphModule ) -> None :
141185 for annotation_func in self .custom_quant_annotations :
142186 annotation_func (gm )
143187
144- def _get_quant_config (self , op : str | OpOverload ) -> Optional [QuantizationConfig ]:
188+ def _get_submodule (self , node : torch .fx .Node ):
189+ """
190+ An example of nn_module_stack
191+ {
192+ 'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
193+ 'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
194+ }
195+ """
196+
197+ nn_module_stack = node .meta .get ("nn_module_stack" )
198+ if nn_module_stack :
199+ module_source_str , module_str = list (nn_module_stack .values ())[- 1 ][
200+ - 1
201+ ].rsplit ("." , 1 )
202+ module_source = importlib .import_module (module_source_str )
203+ return getattr (module_source , module_str )
204+ return None
205+
206+ def _get_quant_config (self , node : torch .fx .Node ) -> Optional [QuantizationConfig ]:
145207 """
146- Priority:
147- 1. is one of use_per_channel_weight_quant_ops
148- 2. quant config
208+ How to pick:
209+ 1. Choose specific submodule config if given.
210+ 2. Pick one if op belongs to use_per_channel_weight_quant_ops
211+ 3. If not 2, pick normal quant config
149212 """
213+ op = node .target
150214 if isinstance (op , str ):
151215 return
152216
153- if op in self .use_per_channel_weight_quant_ops :
154- return self .per_channel_quant_config
217+ config = self .module_qconfig_dict .get (
218+ self ._get_submodule (node ), self .default_quant_config
219+ )
220+
221+ if op in config .use_per_channel_weight_quant_ops :
222+ return config .per_channel_quant_config
155223
156224 if op in self .quant_ops :
157- return self .quant_config
225+ return config .quant_config
158226
159227 print (f"No quant config is implemented for op, { op } " )
160228
161- def _update_per_channel_weight_quant_ops (self , ops : Set [OpOverload ], enable : bool ):
162- if enable :
163- self .use_per_channel_weight_quant_ops .update (ops )
164- else :
165- self .use_per_channel_weight_quant_ops .difference_update (ops )
166-
167229 def add_custom_quant_annotations (
168230 self , custom_quant_annotations : Sequence [Callable ]
169231 ) -> None :
@@ -185,39 +247,29 @@ def annotate(self, model: GraphModule) -> GraphModule:
185247 def get_supported_ops (self ) -> Set [OpOverload ]:
186248 return self .SUPPORTED_OPS
187249
188- def set_quant_config (
189- self , quant_dtype : QuantDtype , is_qat = False , act_observer = None
250+ def set_default_quant_config (
251+ self ,
252+ quant_dtype : QuantDtype ,
253+ is_qat = False ,
254+ is_conv_per_channel = False ,
255+ is_linear_per_channel = False ,
256+ act_observer = None ,
190257 ) -> None :
191- self .quant_dtype = quant_dtype
192- self .is_qat = is_qat
193- if (quant_dtype , is_qat ) not in quant_config_dict :
194- raise RuntimeError (
195- f"the quant config, (quant_dtype: { quant_dtype } , is_qat: { is_qat } ) is not support"
196- )
197-
198- quant_config_fuc , per_channel_quant_config_fuc = quant_config_dict [
199- (quant_dtype , is_qat )
200- ]
201- self .quant_config = (
202- quant_config_fuc (act_observer = act_observer )
203- if act_observer
204- else quant_config_fuc ()
205- )
206- self .per_channel_quant_config = (
207- per_channel_quant_config_fuc (act_observer = act_observer )
208- if act_observer
209- else per_channel_quant_config_fuc ()
258+ self .default_quant_config = ModuleQConfig (
259+ quant_dtype ,
260+ is_qat ,
261+ is_conv_per_channel ,
262+ is_linear_per_channel ,
263+ act_observer ,
210264 )
211265
212- def set_per_channel_conv_quant (self , enable : bool ) -> None :
213- conv_ops = {torch .ops .aten .conv1d .default , torch .ops .aten .conv2d .default }
214- self ._update_per_channel_weight_quant_ops (conv_ops , enable )
215-
216- def set_per_channel_linear_quant (self , enable : bool ) -> None :
217- linear_ops = {
218- torch .ops .aten .linear .default ,
219- }
220- self ._update_per_channel_weight_quant_ops (linear_ops , enable )
266+ def set_submodule_quant_config (
267+ self , submodule : torch .nn .Module , module_qconfig : ModuleQConfig
268+ ) -> None :
269+ """
270+ Set the quant config specific for a submodule
271+ """
272+ self .module_qconfig_dict [submodule ] = module_qconfig
221273
222274 def transform_for_annotation (self , model : GraphModule ) -> GraphModule :
223275 model = ReduceDynamicRange ()(model ).graph_module
0 commit comments