 
 from __future__ import annotations
 
-import copy
 import functools
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
-import torch.nn.functional as F
 from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
 
 from executorch.backends.arm.quantizer import arm_quantizer_utils
-from executorch.backends.arm.quantizer.arm_quantizer_utils import (
-    mark_nodes_as_annotated,
-    propagate_annotation,
-)
-from executorch.backends.arm.quantizer.quantization_annotation import (
-    OP_TO_ANNOTATOR,
-    OperatorConfig,
-    OperatorPatternType,
-)
+from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated
+from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph
+
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.fake_quantize import (
     FakeQuantize,
@@ ... @@
 ]
 
 
-def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
-    supported_operators: Dict[str, List[OperatorPatternType]] = {
-        # Both conv and linear should be able to handle relu + hardtanh fusion since
-        # those are clamp ops
-        "conv2d": [
-            [torch.nn.Conv2d, torch.nn.ReLU],
-            [torch.nn.Conv2d, F.relu],
-            [F.conv2d, torch.nn.ReLU],
-            [F.conv2d, F.relu],
-        ],
-        "linear": [[torch.nn.Linear], [F.linear]],
-        "add": [[torch.add]],
-        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
-        "adaptive_avg_pool2d": [
-            [torch.nn.AdaptiveAvgPool2d],
-            [F.adaptive_avg_pool2d],
-        ],
-        "mul": [[torch.mul]],
-        "sub": [[torch.sub]],
-    }
-    return copy.deepcopy(supported_operators)
-
-
-def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
-    supported_config_and_operators: List[OperatorConfig] = []
-    for quantization_config in [
-        get_symmetric_quantization_config(),
-        get_symmetric_quantization_config(is_per_channel=True),
-    ]:
-        ops = _supported_symmetric_quantized_operators()
-        for pattern_list in ops.values():
-            supported_config_and_operators.append(
-                OperatorConfig(quantization_config, pattern_list)
-            )
-    return copy.deepcopy(supported_config_and_operators)
-
-
 @functools.lru_cache
 def get_symmetric_quantization_config(
     is_per_channel: bool = False,
@@ -179,10 +134,6 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
-def _get_supported_config_and_operators() -> List[OperatorConfig]:
-    return _get_supported_symmetric_config_and_operators()
-
-
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
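As an illustration of the NodeFilterType contract documented above, here is a minimal sketch of a filter an annotator could receive. The predicate (annotate only aten add nodes) is a hypothetical example, not part of this diff:

    def add_only_filter(n: Node) -> bool:
        # Accept a node for annotation only if it is an aten add op.
        return n.target == torch.ops.aten.add.Tensor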
@@ -254,25 +205,6 @@ def not_module_type_or_name_filter(n: Node) -> bool:
 
 
 class ArmQuantizer(Quantizer):
-    supported_config_and_operators = _get_supported_config_and_operators()
-
-    # A list of supported static quantization annotators, in order of application.
-    # For example, fusions come before singular ops.
-    # The name must match the name used when registering the annotator.
-    STATIC_ANNOTATION_ORDER = [
-        "linear",
-        "conv",
-        "adaptive_avg_pool2d",
-        "max_pool2d",
-        "add",
-        "sub",
-        "mul",
-        "mm",
-        "one_to_one",
-        "generic",
-        "upsample_nearest2d",
-    ]
-
     def __init__(self) -> None:
         super().__init__()
         self.global_config: Optional[QuantizationConfig] = None
@@ -329,7 +261,6 @@ def annotate(self, model: GraphModule) -> GraphModule:
             The annotated model.
         """
         model = self._annotate_for_static_quantization_config(model)
-        propagate_annotation(model)
         return model
 
     def _annotate_all_static_patterns(
@@ -351,8 +282,7 @@ def _annotate_all_static_patterns(
         if quantization_config is None:
            return model
 
-        for op in self.STATIC_ANNOTATION_ORDER:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        annotate_graph(model, quantization_config, filter_fn)
         return model
 
     def _annotate_for_static_quantization_config(
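The hunk above collapses the ordered OP_TO_ANNOTATOR dispatch into a single annotate_graph call. A hedged sketch of invoking the consolidated entry point directly, reusing the hypothetical add_only_filter from earlier; only the (model, quantization_config, filter_fn) signature visible in the diff is assumed:

    annotate_graph(graph_module, get_symmetric_quantization_config(), add_only_filter)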
@@ -361,6 +291,9 @@ def _annotate_for_static_quantization_config(
         """Matches the correct QuantizationConfig with the correct module using a filter
         when running _annotate_all_static_patterns.
         """
+        if self.io_config:
+            self._annotate_io(model, self.io_config)
+
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
             self._annotate_all_static_patterns(
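Note that this hunk moves the IO annotation ahead of the module-name and module-type passes, so graph inputs and outputs are claimed by io_config before any pattern annotator can reach them. A minimal configuration sketch, assuming a set_io setter populates self.io_config (if no such setter exists, assigning the attribute directly would be the fallback):

    quantizer = ArmQuantizer()
    quantizer.set_io(get_symmetric_quantization_config())  # set_io is assumed, not shown in this diff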
@@ -379,9 +312,6 @@ def _annotate_for_static_quantization_config(
                 _get_not_module_type_or_name_filter(tp_list, module_name_list),
             )
 
-        if self.io_config:
-            self._annotate_io(model, self.io_config)
-
         return model
 
     def _annotate_io(
@@ -397,44 +327,13 @@ def _annotate_io(
                     node,
                     quantization_config.get_output_act_qspec(),
                 )
-                mark_nodes_as_annotated([node])
+                mark_node_as_annotated(node)
             if node.op == "output":
                 parent = node.all_input_nodes[0]
                 _annotate_input_qspec_map(
                     node, parent, quantization_config.get_input_act_qspec()
                 )
-                mark_nodes_as_annotated([node])
+                mark_node_as_annotated(node)
 
     def validate(self, model: GraphModule) -> None:
         pass
-
-    @classmethod
-    def get_supported_operators(cls) -> List[OperatorConfig]:
-        return cls.supported_config_and_operators
-
-    @classmethod
-    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
-        return list(op_configs)
-
-    @classmethod
-    def get_supported_operator_for_quantization_config(
-        cls, quantization_config: Optional[QuantizationConfig]
-    ) -> List[OperatorPatternType]:
-        if quantization_config is None:
-            all_ops = []
-            for _, ops in cls.supported_config_and_operators:
-                all_ops.extend(ops)
-            return all_ops
-
-        for config, ops in cls.supported_config_and_operators:
-            # note: this assumes each entry in cls.supported_spec_and_operators
-            # corresponds to one spec, e.g. we don't have
-            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
-            # where the first and second entry have the same spec but did not
-            # merge the op list
-            if config == quantization_config:
-                return ops
-        return []
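For orientation, a hedged end-to-end sketch of how the reworked quantizer is typically driven. Only ArmQuantizer, get_symmetric_quantization_config, and annotate appear in this diff; set_global, the capture step, and the prepare_pt2e/convert_pt2e flow are assumed from the standard PT2E quantization API and may differ across PyTorch versions:

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    quantizer = ArmQuantizer()
    # set_global (assumed) fills self.global_config with the symmetric scheme.
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

    # model and example_inputs are placeholders; the capture API varies by version.
    exported = torch.export.export(model, example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)  # drives quantizer.annotate()
    prepared(*example_inputs)                     # calibration forward pass
    quantized = convert_pt2e(prepared)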