
Commit 0504e80

chenfucn authored and shamaksx committed
add qdq debugging example (microsoft#134)
Adding example run_qdq_debug.py
1 parent 60e3f5b commit 0504e80

File tree

5 files changed

+195 -19 lines changed
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
# ONNX Runtime Quantization Example

This folder contains example code for quantizing ResNet50 or MobileNetV2 models. The example has
three parts:

1. Pre-processing
2. Quantization
3. Debugging
## Pre-processing

Pre-processing prepares a float32 model for quantization. Run the following command to pre-process
the model `mobilenetv2-7.onnx`:

```console
python -m onnxruntime.quantization.shape_inference --input mobilenetv2-7.onnx --output mobilenetv2-7-infer.onnx
```
The pre-processing consists of the following optional steps:

- Symbolic Shape Inference. It works best with transformer models.
- ONNX Runtime Model Optimization.
- ONNX Shape Inference.

Quantization requires tensor shape information to perform at its best, and model optimization
also improves the performance of quantization. For instance, a Convolution node followed
by a BatchNormalization node can be merged into a single node during optimization.
Currently we cannot quantize BatchNormalization by itself, but we can quantize the
merged Convolution + BatchNormalization node.

It is highly recommended to run model optimization as part of pre-processing rather than during quantization.
To learn more about each of these steps and their finer controls, run:

```console
python -m onnxruntime.quantization.shape_inference --help
```
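The same pre-processing can also be driven from Python. Below is a minimal sketch, under the assumption that your onnxruntime version exposes `quant_pre_process` in `onnxruntime.quantization.shape_inference` (the function behind the CLI module above); the keyword argument mirrors the CLI flag of the same name:

```python
# Minimal sketch, assuming quant_pre_process is available in your onnxruntime build.
from onnxruntime.quantization.shape_inference import quant_pre_process

# Equivalent to the shape_inference command shown earlier: symbolic shape inference,
# ONNX Runtime model optimization, then ONNX shape inference.
quant_pre_process(
    "mobilenetv2-7.onnx",        # input float32 model
    "mobilenetv2-7-infer.onnx",  # pre-processed output model
    skip_symbolic_shape=False,   # mirrors the --skip_symbolic_shape CLI flag
)
```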
## Quantization

The quantization tool takes the pre-processed float32 model and produces a quantized model.
It's recommended to use Tensor-oriented quantization (QDQ; Quantize and DeQuantize):

```console
python run.py --input_model mobilenetv2-7-infer.onnx --output_model mobilenetv2-7.quant.onnx --calibrate_dataset ./test_images/
```

This will generate the quantized model `mobilenetv2-7.quant.onnx`.

The code in `run.py` creates an input data reader for the model, uses that input data to run
the model and calibrate the quantization parameters for each tensor, and then produces a
quantized model. Lastly, it runs the quantized model. Of these steps, the only part that is
specific to the model is the input data reader, as each model requires different shapes of
input data. All other code can be easily generalized to other models.
For historical reasons, the quantization API performs model optimization by default.
It's highly recommended to turn off model optimization using the parameter
`optimize_model=False`. This way, it is easier for the quantization debugger to match
tensors of the float32 model with those of its quantized model, facilitating the triaging
of quantization loss.
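For reference, the calibrate-and-quantize step in `run.py` (see the diff further down) condenses to roughly the following sketch; the paths and the `QuantFormat.QDQ` choice here are illustrative:

```python
# Condensed sketch of the quantization step in run.py; paths are illustrative.
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

import resnet50_data_reader

input_model_path = "mobilenetv2-7-infer.onnx"   # pre-processed float32 model
output_model_path = "mobilenetv2-7.quant.onnx"  # quantized model to write

# Data reader that feeds calibration images to the calibrator.
dr = resnet50_data_reader.ResNet50DataReader("./test_images/", input_model_path)

quantize_static(
    input_model_path,
    output_model_path,
    dr,
    quant_format=QuantFormat.QDQ,  # tensor-oriented (QDQ) quantization
    per_channel=False,
    weight_type=QuantType.QInt8,
    optimize_model=False,          # keep optimization in the pre-processing step
)
```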
## Debugging

Quantization is not a lossless process. Sometimes it results in a significant loss of accuracy.
To help locate the source of these losses, our quantization debugging tool matches up
the weight tensors of the float32 model with those of the quantized model. If an input data reader
is provided, our debugger can also run both models with the same input and compare their
corresponding tensors:

```console
python run_qdq_debug.py --float_model mobilenetv2-7-infer.onnx --qdq_model mobilenetv2-7.quant.onnx --calibrate_dataset ./test_images/
```
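The debugger is built on the helpers in `onnxruntime.quantization.qdq_loss_debug` that `run_qdq_debug.py` (added in this commit) imports; the weight comparison on its own reduces to roughly this sketch:

```python
# Weight-only comparison sketch; run_qdq_debug.py further down shows the full flow,
# including activation matching driven by a calibration data reader.
from onnxruntime.quantization.qdq_loss_debug import (
    compute_weight_error, create_weight_matching)

matched_weights = create_weight_matching(
    "mobilenetv2-7-infer.onnx", "mobilenetv2-7.quant.onnx"
)
weights_error = compute_weight_error(matched_weights)
for weight_name, err in weights_error.items():
    print(f"Cross model error of '{weight_name}': {err}")
```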
If you quantized a model with optimization turned on and find that the debugging tool cannot
match certain float32 model tensors with their quantized counterparts, you can try running the
pre-processor to produce the optimized model, then compare the optimized model with the quantized model.

For instance, suppose you have a model `abc_float32_model.onnx` and a quantized model
`abc_quantized.onnx`, and during the quantization process optimization was left on
by default. You can run the following command to produce an optimized float32 model:

```console
python -m onnxruntime.quantization.shape_inference --input abc_float32_model.onnx --output abc_optimized.onnx --skip_symbolic_shape True
```

Then run the debugger comparing `abc_optimized.onnx` with `abc_quantized.onnx`.
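For example (the `--calibrate_dataset` path here is a placeholder; `run_qdq_debug.py` uses the image-classification data reader from this folder, so point it at calibration data that suits your model):

```console
python run_qdq_debug.py --float_model abc_optimized.onnx --qdq_model abc_quantized.onnx --calibrate_dataset ./test_images/
```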

quantization/image_classification/cpu/ReadMe.txt

Lines changed: 0 additions & 2 deletions
This file was deleted.

quantization/image_classification/cpu/resnet50_data_reader.py

Lines changed: 19 additions & 17 deletions
@@ -39,23 +39,25 @@ def _preprocess_images(images_folder: str, height: int, width: int, size_limit=0
 
 class ResNet50DataReader(CalibrationDataReader):
     def __init__(self, calibration_image_folder: str, model_path: str):
-        self.image_folder = calibration_image_folder
-        self.model_path = model_path
-        self.preprocess_flag = True
-        self.enum_data_dicts = []
-        self.datasize = 0
+        self.enum_data = None
+
+        # Use inference session to get input shape.
+        session = onnxruntime.InferenceSession(model_path, None)
+        (_, _, height, width) = session.get_inputs()[0].shape
+
+        # Convert image to input data
+        self.nhwc_data_list = _preprocess_images(
+            calibration_image_folder, height, width, size_limit=0
+        )
+        self.input_name = session.get_inputs()[0].name
+        self.datasize = len(self.nhwc_data_list)
 
     def get_next(self):
-        if self.preprocess_flag:
-            self.preprocess_flag = False
-            session = onnxruntime.InferenceSession(self.model_path, None)
-            (_, _, height, width) = session.get_inputs()[0].shape
-            nhwc_data_list = _preprocess_images(
-                self.image_folder, height, width, size_limit=0
+        if self.enum_data is None:
+            self.enum_data = iter(
+                [{self.input_name: nhwc_data} for nhwc_data in self.nhwc_data_list]
             )
-            input_name = session.get_inputs()[0].name
-            self.datasize = len(nhwc_data_list)
-            self.enum_data_dicts = iter(
-                [{input_name: nhwc_data} for nhwc_data in nhwc_data_list]
-            )
-        return next(self.enum_data_dicts, None)
+        return next(self.enum_data, None)
+
+    def rewind(self):
+        self.enum_data = None
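For context, a hypothetical standalone use of the reworked reader looks like the sketch below (model and image paths are placeholders); `rewind()` is what lets `run_qdq_debug.py` reuse one reader for both the float and the QDQ model runs:

```python
# Illustrative only; paths are placeholders.
import resnet50_data_reader

dr = resnet50_data_reader.ResNet50DataReader(
    "./test_images/", "mobilenetv2-7-infer.onnx"
)

# get_next() yields {input_name: nhwc_batch} dicts until the data is exhausted.
batch = dr.get_next()
while batch is not None:
    batch = dr.get_next()

# rewind() resets the iterator so the same reader can feed a second model run.
dr.rewind()
```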

quantization/image_classification/cpu/run.py

Lines changed: 4 additions & 0 deletions
@@ -52,13 +52,17 @@ def main():
     dr = resnet50_data_reader.ResNet50DataReader(
         calibration_dataset_path, input_model_path
     )
+
+    # Calibrate and quantize model
+    # Turn off model optimization during quantization
     quantize_static(
         input_model_path,
         output_model_path,
         dr,
         quant_format=args.quant_format,
         per_channel=args.per_channel,
         weight_type=QuantType.QInt8,
+        optimize_model=False,
     )
     print("Calibrated and quantized model saved.")

quantization/image_classification/cpu/run_qdq_debug.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import argparse
import onnx
from onnxruntime.quantization.qdq_loss_debug import (
    collect_activations, compute_activation_error, compute_weight_error,
    create_activation_matching, create_weight_matching,
    modify_model_output_intermediate_tensors)

import resnet50_data_reader


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--float_model", required=True, help="Path to original floating point model"
    )
    parser.add_argument("--qdq_model", required=True, help="Path to qdq model")
    parser.add_argument(
        "--calibrate_dataset", default="./test_images", help="calibration data set"
    )
    args = parser.parse_args()
    return args


def _generate_aug_model_path(model_path: str) -> str:
    aug_model_path = (
        model_path[: -len(".onnx")] if model_path.endswith(".onnx") else model_path
    )
    return aug_model_path + ".save_tensors.onnx"


def main():
    # Process input parameters and setup model input data reader
    args = get_args()
    float_model_path = args.float_model
    qdq_model_path = args.qdq_model
    calibration_dataset_path = args.calibrate_dataset

    print("------------------------------------------------\n")
    print("Comparing weights of float model vs qdq model.....")

    matched_weights = create_weight_matching(float_model_path, qdq_model_path)
    weights_error = compute_weight_error(matched_weights)
    for weight_name, err in weights_error.items():
        print(f"Cross model error of '{weight_name}': {err}\n")

    print("------------------------------------------------\n")
    print("Augmenting models to save intermediate activations......")

    aug_float_model = modify_model_output_intermediate_tensors(float_model_path)
    aug_float_model_path = _generate_aug_model_path(float_model_path)
    onnx.save(
        aug_float_model,
        aug_float_model_path,
        save_as_external_data=False,
    )
    del aug_float_model

    aug_qdq_model = modify_model_output_intermediate_tensors(qdq_model_path)
    aug_qdq_model_path = _generate_aug_model_path(qdq_model_path)
    onnx.save(
        aug_qdq_model,
        aug_qdq_model_path,
        save_as_external_data=False,
    )
    del aug_qdq_model

    print("------------------------------------------------\n")
    print("Running the augmented floating point model to collect activations......")
    input_data_reader = resnet50_data_reader.ResNet50DataReader(
        calibration_dataset_path, float_model_path
    )
    float_activations = collect_activations(aug_float_model_path, input_data_reader)

    print("------------------------------------------------\n")
    print("Running the augmented qdq model to collect activations......")
    input_data_reader.rewind()
    qdq_activations = collect_activations(aug_qdq_model_path, input_data_reader)

    print("------------------------------------------------\n")
    print("Comparing activations of float model vs qdq model......")

    act_matching = create_activation_matching(qdq_activations, float_activations)
    act_error = compute_activation_error(act_matching)
    for act_name, err in act_error.items():
        print(f"Cross model error of '{act_name}': {err['xmodel_err']} \n")
        print(f"QDQ error of '{act_name}': {err['qdq_err']} \n")


if __name__ == "__main__":
    main()

0 commit comments
