@@ -177,6 +177,29 @@ def __post_init__(self):
 
 
 class QnnQuantizer(Quantizer):
+    """
+    QnnQuantizer is a quantization annotator designed for QNN backends.
+    It uses OP_ANNOTATOR, a dictionary mapping OpOverload to annotator functions,
+    to determine how each node should be annotated for quantization.
+
+    Example usage:
+        quantizer = QnnQuantizer()
+        quantizer.set_default_quant_config(
+            quant_dtype=QuantDtype.use_8a8w,
+            is_qat=False,
+            is_conv_per_channel=True,
+            is_linear_per_channel=True,
+            act_observer=MovingAverageMinMaxObserver,
+        )
+        quantizer.set_block_size_map({"conv2d": (1, 128, 1, 1)})
+        quantizer.set_submodule_qconfig_list([
+            (get_submodule_type_predicate("Add"), ModuleQConfig(quant_dtype=QuantDtype.use_16a4w))
+        ])
+        quantizer.add_custom_quant_annotations(...)
+        quantizer.add_discard_nodes([...])  # node names to skip during annotation
+        quantizer.add_discard_ops([...])  # op overloads to skip during annotation
+    """
+
     SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())
 
     def __init__(self):
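For context, a minimal sketch of how this quantizer plugs into the PT2E flow (`MyModel` and `example_inputs` are hypothetical placeholders; `prepare_pt2e`/`convert_pt2e` are the standard entry points from `torch.ao.quantization.quantize_pt2e`):

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    quantizer = QnnQuantizer()
    quantizer.set_default_quant_config(quant_dtype=QuantDtype.use_8a8w)

    # Export the eager model to an FX GraphModule (example_inputs is a tuple).
    exported = torch.export.export(MyModel().eval(), example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)  # runs transform_for_annotation, then annotate
    prepared(*example_inputs)  # calibration forward pass
    quantized = convert_pt2e(prepared)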
@@ -193,6 +216,11 @@ def __init__(self):
         self.discard_nodes: Set[str] = set()
 
     def _annotate(self, gm: GraphModule) -> None:
+        """
+        Annotates the nodes of the provided GraphModule in place, based on user-defined quant configs, during prepare_pt2e.
+
+        Nodes without a quant config, and nodes explicitly listed in `self.discard_nodes`, are left unannotated.
+        """
         for node in gm.graph.nodes:
             if node.name in self.discard_nodes:
                 continue
@@ -206,18 +234,34 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
             annotation_func(gm)
 
     def _get_submodule_qconfig(self, node: torch.fx.Node):
+        """
+        Retrieves the `ModuleQConfig` for a given node by evaluating the predicates in `submodule_qconfig_list` and returning the config of the first match.
+        Submodule-specific quant configs can be registered via the `set_submodule_qconfig_list` method.
+
+        Args:
+            node (torch.fx.Node): The node for which to retrieve the quant config.
+
+        Returns:
+            ModuleQConfig: The matched submodule config, or the default config if no match is found.
+        """
         for func, qconfig in self.submodule_qconfig_list:
             if func(node):
                 return qconfig
         return self.default_quant_config
 
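As a sketch of the matching behavior: each entry in `submodule_qconfig_list` pairs a predicate with a `ModuleQConfig`, and the first predicate that returns True wins. The predicate below is a hypothetical stand-in for helpers such as `get_submodule_type_predicate`:

    def match_attention_nodes(node: torch.fx.Node) -> bool:
        # Hypothetical predicate: select nodes whose name marks an attention block.
        return "attention" in node.name

    quantizer.set_submodule_qconfig_list([
        (match_attention_nodes, ModuleQConfig(quant_dtype=QuantDtype.use_16a4w)),
    ])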
     def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
         """
-        How to pick:
-        1. is one of per_block_quant_config
-        2. Pick specific submodule config if given.
-        3. Pick one if op belongs to use_per_channel_weight_quant_ops
-        4. If not 3, pick normal quant config
+        Select the quant config for a node based on priority.
+
+        Priority order:
+        1. Per-block quant config, if a block size is set for the node.
+        2. Submodule-specific config, if a predicate matches.
+        3. Per-channel config, if the op is in the per-channel set.
+        4. Default quant config, if the op is supported.
+
+        Args:
+            node (torch.fx.Node): The node to get the quant config for.
+
+        Returns:
+            Optional[QuantizationConfig]: The selected quant config, or None if no config applies.
         """
         op = node.target
         if isinstance(op, str):
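To illustrate the priority order, assume a node named "conv2d_1" (hypothetical) that both has an entry in the block-size map and matches a submodule predicate; the per-block config takes precedence:

    quantizer.set_block_size_map({"conv2d_1": (1, 64, 1, 1)})
    quantizer.set_submodule_qconfig_list([
        (get_submodule_type_predicate("Conv2d"), ModuleQConfig(quant_dtype=QuantDtype.use_16a4w)),
    ])
    # For the "conv2d_1" node, _get_quant_config picks the per-block config
    # (priority 1) rather than the matching 16a4w submodule config (priority 2).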
@@ -241,22 +285,49 @@ def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]
     def add_custom_quant_annotations(
         self, custom_quant_annotations: Sequence[Callable]
     ) -> None:
+        """
+        Add custom annotation functions to be applied during prepare_pt2e.
+
+        Args:
+            custom_quant_annotations (Sequence[Callable]): A sequence of functions that each take a GraphModule and perform custom annotation.
+        """
         self.custom_quant_annotations = custom_quant_annotations
 
     def add_discard_nodes(self, nodes: Sequence[str]) -> None:
+        """
+        Specifies node names to exclude from quantization.
+        """
         self.discard_nodes = set(nodes)
 
     def add_discard_ops(self, ops: Sequence[OpOverload]) -> None:
+        """
+        Specifies OpOverloads to exclude from quantization.
+        """
         for op in ops:
             self.quant_ops.remove(op)
 
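A short sketch of both exclusion hooks; the node name is hypothetical, while `torch.ops.aten.softmax.int` is a standard ATen overload:

    quantizer.add_discard_nodes(["conv2d_7"])  # skip one node by name
    quantizer.add_discard_ops([torch.ops.aten.softmax.int])  # skip every instance of an op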
     def annotate(self, model: GraphModule) -> GraphModule:
+        """
+        Annotates the GraphModule during prepare_pt2e.
+
+        Args:
+            model (GraphModule): The FX GraphModule to annotate.
+
+        Returns:
+            GraphModule: The annotated model.
+        """
         self._annotate(model)
         self._annotate_custom_annotation(model)
 
         return model
 
     def get_supported_ops(self) -> Set[OpOverload]:
+        """
+        Returns the set of supported OpOverloads for quantization.
+
+        Returns:
+            Set[OpOverload]: Supported ops.
+        """
         return self.SUPPORTED_OPS
 
     def set_default_quant_config(
@@ -267,6 +338,17 @@ def set_default_quant_config(
         is_linear_per_channel=False,
         act_observer=None,
     ) -> None:
+        """
+        Set the default quant config for the quantizer.
+
+        Args:
+            quant_dtype (QuantDtype): Specifies the quantized data type. By default, 8-bit activations and weights (8a8w) are used.
+            is_qat (bool, optional): Enables Quantization-Aware Training (QAT) mode. Defaults to Post-Training Quantization (PTQ) mode.
+            is_conv_per_channel (bool, optional): Enables per-channel quantization for convolution operations.
+            is_linear_per_channel (bool, optional): Enables per-channel quantization for linear (fully connected) operations.
+            act_observer (Optional[UniformQuantizationObserverBase], optional): Custom observer for activation quantization. If not specified, the default observer is determined by `QUANT_CONFIG_DICT`.
+        """
         self.default_quant_config = ModuleQConfig(
             quant_dtype,
             is_qat,
@@ -276,6 +358,12 @@ def set_default_quant_config(
         )
 
     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
+        """
+        Set the mapping from node names to block sizes for per-block quantization.
+
+        Args:
+            block_size_map (Dict[str, Tuple]): Mapping from node name to block size.
+        """
         self.block_size_map = block_size_map
 
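As a sketch, the block size is given per weight dimension; assuming an OIHW conv weight layout (an assumption for illustration) and a hypothetical node name, the tuple below groups 128 input channels into each quantization block:

    quantizer.set_block_size_map({"conv2d_3": (1, 128, 1, 1)})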
     def set_submodule_qconfig_list(
@@ -288,6 +376,15 @@ def set_submodule_qconfig_list(
         self.submodule_qconfig_list = submodule_qconfig_list
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
+        """
+        Applies QNN-specific transformations before annotation during prepare_pt2e.
+
+        Args:
+            model (GraphModule): The FX GraphModule to transform.
+
+        Returns:
+            GraphModule: The transformed model.
+        """
         return QnnPassManager().transform_for_annotation_pipeline(model)
 
     def validate(self, model: GraphModule) -> None: