
Commit 3817862

jerryzh168 authored and facebook-github-bot committed
Add verification for quant examples and update README.md to include quantization
Summary: att

Reviewed By: guangy10

Differential Revision: D48337276

fbshipit-source-id: 5f716559d0774c69e77c70c778e213c7c8f02492
1 parent 6851268 commit 3817862

File tree

examples/README.md
examples/quantization/example.py

2 files changed: +96 -4 lines changed

examples/README.md

Lines changed: 17 additions & 0 deletions
@@ -51,6 +51,23 @@ Use `-h` (or `--help`) to see all the supported models.
 buck2 run examples/executor_runner:executor_runner -- --model_path mv2.pte
 ```
 
+## Quantization
+See the [Quantization Flow Docs](/docs/website/docs/tutorials/quantization_flow.md).
+
+You can run the quantization example with the following command:
+```bash
+buck2 run executorch/examples/quantization:example -- --model_name="mv2" # for MobileNetV2
+```
+It will print both the original model after capture and the quantized model.
+
+The flow produces a quantized model that can be lowered through a partitioner or run on the runtime directly.
+
+
+You can also list the valid quantized example models by running:
+```bash
+buck2 run executorch/examples/quantization:example -- --help
+```
+
 ## Dependencies
 
 Various models listed in this directory have dependencies on some other packages, e.g. torchvision, torchaudio.
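The quantization example added in this commit follows the PyTorch 2.0 export (PT2E) flow with XNNPACKQuantizer. As a reading aid, here is a minimal sketch of that flow, mirroring the calls made in examples/quantization/example.py below; the `model` and `example_inputs` arguments are hypothetical placeholders for any of the supported example models:

```python
import copy

import torch
import torch._export as export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import XNNPACKQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)


def quantize_pt2e_sketch(model: torch.nn.Module, example_inputs):
    """Minimal sketch of the PT2E quantization flow used by the example script."""
    m = model.eval()
    # Capture the model into a graph before quantization.
    m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))
    # Configure XNNPACKQuantizer with per-channel symmetric quantization.
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
    # Insert observers, run a calibration pass, then convert to a quantized model.
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)
    m = convert_pt2e(m)
    return m
```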

examples/quantization/example.py

Lines changed: 79 additions & 4 deletions
@@ -7,7 +7,18 @@
 import argparse
 import copy
 
+import torch
 import torch._export as export
+from torch.ao.ns.fx.utils import compute_sqnr
+from torch.ao.quantization import (  # @manual
+    default_per_channel_symmetric_qnnpack_qconfig,
+    QConfigMapping,
+)
+from torch.ao.quantization.backend_config import get_executorch_backend_config
+from torch.ao.quantization.quantize_fx import (
+    _convert_to_reference_decomposed_fx,
+    prepare_fx,
+)
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import XNNPACKQuantizer
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
@@ -21,6 +32,7 @@
 
 
 def quantize(model_name, model, example_inputs):
+    """This is the official recommended flow for quantization in pytorch 2.0 export"""
     m = model.eval()
     m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))
     print("original model:", m)
@@ -38,23 +50,86 @@ def quantize(model_name, model, example_inputs):
     # aten = export_to_ff(model_name, m, copy.deepcopy(example_inputs))
 
 
+def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_inputs):
+    """This is a verification against the fx graph mode quantization flow as a sanity check"""
+    model.eval()
+    m_copy = copy.deepcopy(model)
+    m = model
+
+    # 1. pytorch 2.0 export quantization flow (recommended/default flow)
+    m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))
+    quantizer = XNNPACKQuantizer()
+    quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+    quantizer.set_global(quantization_config)
+    m = prepare_pt2e(m, quantizer)
+    # calibration
+    after_prepare_result = m(*example_inputs)
+    m = convert_pt2e(m)
+    after_quant_result = m(*example_inputs)
+
+    # 2. the previous fx graph mode quantization reference flow
+    qconfig = default_per_channel_symmetric_qnnpack_qconfig
+    qconfig_mapping = QConfigMapping().set_global(qconfig)
+    backend_config = get_executorch_backend_config()
+    m_fx = prepare_fx(
+        m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
+    )
+    after_prepare_result_fx = m_fx(*example_inputs)
+    m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
+    after_quant_result_fx = m_fx(*example_inputs)
+
+    # 3. compare results
+    # NB: this check is more useful for QAT; for PTQ we only insert observers, which do not change
+    # the output of a model, so it is just testing the numerical difference between the two captures.
+    # For QAT it also tests whether the fake quant placement matches.
+    # The results are not exactly the same because capture changes numerics, but they are still very close.
+    print("m:", m)
+    print("m_fx:", m_fx)
+    print("prepare sqnr:", compute_sqnr(after_prepare_result, after_prepare_result_fx))
+    assert compute_sqnr(after_prepare_result, after_prepare_result_fx) > 100
+    print("quant diff max:", torch.max(after_quant_result - after_quant_result_fx))
+    assert torch.max(after_quant_result - after_quant_result_fx) < 1e-1
+    print("quant sqnr:", compute_sqnr(after_quant_result, after_quant_result_fx))
+    assert compute_sqnr(after_quant_result, after_quant_result_fx) > 30
+
+
 if __name__ == "__main__":
+    # Note: for mv3, the mul op is not supported in XNNPACKQuantizer yet; it could be supported soon.
+    QUANT_MODEL_NAME_TO_MODEL = {
+        name: MODEL_NAME_TO_MODEL[name] for name in ["linear", "add", "add_mul", "mv2"]
+    }
+
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-m",
         "--model_name",
         required=True,
-        help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
+        help=f"Provide model name. Valid ones: {list(QUANT_MODEL_NAME_TO_MODEL.keys())}",
+    )
+    parser.add_argument(
+        "-ve",
+        "--verify",
+        action="store_true",
+        required=False,
+        default=False,
+        help="flag for verifying XNNPACKQuantizer against fx graph mode quantization",
     )
 
     args = parser.parse_args()
 
-    if args.model_name not in MODEL_NAME_TO_MODEL:
+    if not args.verify and args.model_name not in QUANT_MODEL_NAME_TO_MODEL:
         raise RuntimeError(
-            f"Model {args.model_name} is not a valid name. "
-            f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
+            f"Model {args.model_name} is not a valid name or not quantizable right now. "
+            "Please contact the executorch team if you want to learn why or how to support "
+            "quantization for the requested model. "
+            f"Available models are {list(QUANT_MODEL_NAME_TO_MODEL.keys())}."
         )
 
     model, example_inputs = MODEL_NAME_TO_MODEL[args.model_name]()
 
+    if args.verify:
+        verify_xnnpack_quantizer_matching_fx_quant_model(
+            args.model_name, model, example_inputs
+        )
+
     quantize(args.model_name, model, example_inputs)
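To exercise the new verification path against fx graph mode quantization, the `--verify` flag added above can be combined with the buck target from the README. A sketch of the invocation, assuming the same target and model name shown in the README:

```bash
buck2 run executorch/examples/quantization:example -- --model_name="mv2" --verify
```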
