
 from __future__ import annotations

-import copy
 import functools
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Callable, Dict, List, Optional

 import torch
-import torch.nn.functional as F
 from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager

 from executorch.backends.arm.quantizer import arm_quantizer_utils
-from executorch.backends.arm.quantizer.arm_quantizer_utils import (
-    mark_nodes_as_annotated,
-    propagate_annotation,
-)
-from executorch.backends.arm.quantizer.quantization_annotation import (
-    OP_TO_ANNOTATOR,
-    OperatorConfig,
-    OperatorPatternType,
-)
+from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated
+from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph
+
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.fake_quantize import (
     FakeQuantize,
 ]


-def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
-    supported_operators: Dict[str, List[OperatorPatternType]] = {
-        # Both conv and linear should be able to handle relu + hardtanh fusion since
-        # those are clamp ops
-        "conv2d": [
-            [torch.nn.Conv2d, torch.nn.ReLU],
-            [torch.nn.Conv2d, F.relu],
-            [F.conv2d, torch.nn.ReLU],
-            [F.conv2d, F.relu],
-        ],
-        "linear": [[torch.nn.Linear], [F.linear]],
-        "add": [[torch.add]],
-        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
-        "adaptive_avg_pool2d": [
-            [torch.nn.AdaptiveAvgPool2d],
-            [F.adaptive_avg_pool2d],
-        ],
-        "mul": [[torch.mul]],
-        "sub": [[torch.sub]],
-        "min_max": [[torch.min], [torch.max]],
-    }
-    return copy.deepcopy(supported_operators)
-
-
-def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
-    supported_config_and_operators: List[OperatorConfig] = []
-    for quantization_config in [
-        get_symmetric_quantization_config(),
-        get_symmetric_quantization_config(is_per_channel=True),
-    ]:
-        ops = _supported_symmetric_quantized_operators()
-        for pattern_list in ops.values():
-            supported_config_and_operators.append(
-                OperatorConfig(quantization_config, pattern_list)
-            )
-    return copy.deepcopy(supported_config_and_operators)
-
-
 @functools.lru_cache
 def get_symmetric_quantization_config(
     is_per_channel: bool = False,
@@ -180,10 +134,6 @@ def get_symmetric_quantization_config(
     return quantization_config


-def _get_supported_config_and_operators() -> List[OperatorConfig]:
-    return _get_supported_symmetric_config_and_operators()
-
-
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
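# Illustrative sketch (not part of this change): a minimal filter matching the
# NodeFilterType signature described above. The specific target check is a
# hypothetical example.
def _is_add_node(n: Node) -> bool:
    # Only annotate aten.add call_function nodes.
    return n.op == "call_function" and n.target in (
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add_.Tensor,
    )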
@@ -255,26 +205,6 @@ def not_module_type_or_name_filter(n: Node) -> bool:


 class ArmQuantizer(Quantizer):
-    supported_config_and_operators = _get_supported_config_and_operators()
-
-    # A list of supported static quantization annotators, in order of application.
-    # For example, fusions come before singular ops.
-    # The name must match the name used when registering the annotator.
-    STATIC_ANNOTATION_ORDER = [
-        "linear",
-        "conv",
-        "adaptive_avg_pool2d",
-        "max_pool2d",
-        "add",
-        "sub",
-        "mul",
-        "min_max",
-        "mm",
-        "one_to_one",
-        "generic",
-        "upsample_nearest2d",
-    ]
-
     def __init__(self) -> None:
         super().__init__()
         self.global_config: Optional[QuantizationConfig] = None
@@ -331,7 +261,6 @@ def annotate(self, model: GraphModule) -> GraphModule:
             The annotated model.
         """
         model = self._annotate_for_static_quantization_config(model)
-        propagate_annotation(model)
         return model

     def _annotate_all_static_patterns(
@@ -353,8 +282,7 @@ def _annotate_all_static_patterns(
         if quantization_config is None:
             return model

-        for op in self.STATIC_ANNOTATION_ORDER:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        annotate_graph(model, quantization_config, filter_fn)
         return model

     def _annotate_for_static_quantization_config(
@@ -363,6 +291,9 @@ def _annotate_for_static_quantization_config(
         """Matches the correct QuantizationConfig with the correct module using a filter
         when running _annotate_all_static_patterns.
         """
+        if self.io_config:
+            self._annotate_io(model, self.io_config)
+
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
             self._annotate_all_static_patterns(
@@ -381,9 +312,6 @@ def _annotate_for_static_quantization_config(
                 _get_not_module_type_or_name_filter(tp_list, module_name_list),
             )

-        if self.io_config:
-            self._annotate_io(model, self.io_config)
-
         return model

     def _annotate_io(
@@ -399,44 +327,13 @@ def _annotate_io(
                     node,
                     quantization_config.get_output_act_qspec(),
                 )
-                mark_nodes_as_annotated([node])
+                mark_node_as_annotated(node)
             if node.op == "output":
                 parent = node.all_input_nodes[0]
                 _annotate_input_qspec_map(
                     node, parent, quantization_config.get_input_act_qspec()
                 )
-                mark_nodes_as_annotated([node])
+                mark_node_as_annotated(node)

     def validate(self, model: GraphModule) -> None:
         pass
-
-    @classmethod
-    def get_supported_operators(cls) -> List[OperatorConfig]:
-        return cls.supported_config_and_operators
-
-    @classmethod
-    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
-        return list(op_configs)
-
-    @classmethod
-    def get_supported_operator_for_quantization_config(
-        cls, quantization_config: Optional[QuantizationConfig]
-    ) -> List[OperatorPatternType]:
-        if quantization_config is None:
-            all_ops = []
-            for _, ops in cls.supported_config_and_operators:
-                all_ops.extend(ops)
-            return all_ops
-
-        for config, ops in cls.supported_config_and_operators:
-            # note: this assumes each entry in cls.supported_spec_and_operators
-            # corresponds to one spec, e.g. we don't have
-            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
-            # where the first and second entry have the same spec but did not
-            # merge the op list
-            if config == quantization_config:
-                return ops
-        return []
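# Minimal usage sketch (illustrative only, not part of this diff). It assumes
# the quantizer's set_global() setter (not shown in these hunks) and the
# standard PT2E prepare/convert flow; `model` and `example_inputs` are
# placeholders.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

quantizer = ArmQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)  # runs annotate(), which now calls annotate_graph()
prepared(*example_inputs)  # calibrate observers
quantized = convert_pt2e(prepared)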