 import argparse
 import copy

+import torch
 import torch._export as export
+from torch.ao.ns.fx.utils import compute_sqnr
+from torch.ao.quantization import (  # @manual
+    default_per_channel_symmetric_qnnpack_qconfig,
+    QConfigMapping,
+)
+from torch.ao.quantization.backend_config import get_executorch_backend_config
+from torch.ao.quantization.quantize_fx import (
+    _convert_to_reference_decomposed_fx,
+    prepare_fx,
+)
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import XNNPACKQuantizer
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
 )


 def quantize(model_name, model, example_inputs):
+    """This is the official recommended flow for quantization in pytorch 2.0 export"""
     m = model.eval()
     m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))
     print("original model:", m)
@@ -38,23 +50,86 @@ def quantize(model_name, model, example_inputs):
     # aten = export_to_ff(model_name, m, copy.deepcopy(example_inputs))


+
+def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_inputs):
+    """This is a verification against fx graph mode quantization flow as a sanity check"""
+    model.eval()
+    m_copy = copy.deepcopy(model)
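+    # keep an untouched copy so the fx graph mode flow below starts from identical weights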
+    m = model
+
+    # 1. pytorch 2.0 export quantization flow (recommended/default flow)
+    m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))
+    quantizer = XNNPACKQuantizer()
+    quantization_config = get_symmetric_quantization_config(is_per_channel=True)
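+    # per-channel symmetric quantization, the pt2e counterpart of the per-channel
+    # symmetric qnnpack qconfig used in the fx flow below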
+    quantizer.set_global(quantization_config)
+    m = prepare_pt2e(m, quantizer)
+    # calibration
+    after_prepare_result = m(*example_inputs)
+    m = convert_pt2e(m)
+    after_quant_result = m(*example_inputs)
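+    # prepare_pt2e inserts observers according to the quantizer's annotations;
+    # running the example inputs calibrates them, and convert_pt2e uses the
+    # observed statistics to produce a reference quantized model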
+
+    # 2. the previous fx graph mode quantization reference flow
+    qconfig = default_per_channel_symmetric_qnnpack_qconfig
+    qconfig_mapping = QConfigMapping().set_global(qconfig)
+    backend_config = get_executorch_backend_config()
+    m_fx = prepare_fx(
+        m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
+    )
+    after_prepare_result_fx = m_fx(*example_inputs)
+    m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
+    after_quant_result_fx = m_fx(*example_inputs)
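+    # _convert_to_reference_decomposed_fx emits a reference quantized model with
+    # decomposed quantize/dequantize ops, comparable to the convert_pt2e output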
+
+    # 3. compare results
+    # NB: this check is more useful for QAT: for PTQ we are only inserting observers, which do not
+    # change the output of a model, so it is just testing the numerical difference between the two
+    # captures; for QAT it also tests whether the fake quant placement matches.
+    # results are not exactly the same since capture changes numerics, but they are still very close
+    print("m:", m)
+    print("m_fx:", m_fx)
+    print("prepare sqnr:", compute_sqnr(after_prepare_result, after_prepare_result_fx))
+    assert compute_sqnr(after_prepare_result, after_prepare_result_fx) > 100
+    print("quant diff max:", torch.max(after_quant_result - after_quant_result_fx))
+    assert torch.max(after_quant_result - after_quant_result_fx) < 1e-1
+    print("quant sqnr:", compute_sqnr(after_quant_result, after_quant_result_fx))
+    assert compute_sqnr(after_quant_result, after_quant_result_fx) > 30
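+    # SQNR (signal-to-quantization-noise ratio) is reported in dB: higher means the
+    # two outputs agree more closely, so the thresholds above demand near-identical
+    # prepared outputs (> 100 dB) and closely matching quantized outputs (> 30 dB)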
+
+
 if __name__ == "__main__":
+    # Note: for mv3, the mul op is not supported in XNNPACKQuantizer; that could be supported soon
+    QUANT_MODEL_NAME_TO_MODEL = {
+        name: MODEL_NAME_TO_MODEL[name] for name in ["linear", "add", "add_mul", "mv2"]
+    }
+
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-m",
         "--model_name",
         required=True,
-        help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
+        help=f"Provide model name. Valid ones: {list(QUANT_MODEL_NAME_TO_MODEL.keys())}",
+    )
+    parser.add_argument(
+        "-ve",
+        "--verify",
+        action="store_true",
+        required=False,
+        default=False,
+        help="flag for verifying XNNPACKQuantizer against fx graph mode quantization",
     )

     args = parser.parse_args()

-    if args.model_name not in MODEL_NAME_TO_MODEL:
+    if not args.verify and args.model_name not in QUANT_MODEL_NAME_TO_MODEL:
         raise RuntimeError(
-            f"Model {args.model_name} is not a valid name. "
-            f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
+            f"Model {args.model_name} is not a valid name or is not quantizable right now; "
+            "please contact the executorch team if you want to learn why or how to support "
+            "quantization for the requested model. "
+            f"Available models are {list(QUANT_MODEL_NAME_TO_MODEL.keys())}."
         )

     model, example_inputs = MODEL_NAME_TO_MODEL[args.model_name]()

+    if args.verify:
+        verify_xnnpack_quantizer_matching_fx_quant_model(
+            args.model_name, model, example_inputs
+        )
+
     quantize(args.model_name, model, example_inputs)
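
For reference, here is a minimal sketch of what the SQNR checks above measure, assuming compute_sqnr follows the usual 20 * log10(signal / noise) definition and reports decibels (higher means closer agreement):

    import torch
    from torch.ao.ns.fx.utils import compute_sqnr

    x = torch.randn(8)
    # a tiny perturbation gives a very high SQNR, the regime the prepare check (> 100) expects
    print(compute_sqnr(x, x + 1e-5 * torch.randn(8)))
    # a coarser perturbation gives a lower SQNR, closer to the quantized-output check (> 30)
    print(compute_sqnr(x, x + 1e-2 * torch.randn(8)))

With the flags added above, the verification path can be exercised with something like python example.py -m mv2 --verify (script name hypothetical; -m and --verify come from the argparse setup in the diff).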