Skip to content

Commit 811eec3

Browse files
chenfucnshamaksx
authored andcommitted
Quantization tool example bug fix (microsoft#133)
In ResNet50DataReader, it uses an onnx session to obtain the model input shape. However it passes a madeup model name to the onnx session, resulting in file not found error. This change provide the original float model path to the data reader
1 parent ea0f813 commit 811eec3

File tree

2 files changed

+90
-78
lines changed

2 files changed

+90
-78
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import numpy
2+
import onnxruntime
3+
import os
4+
from onnxruntime.quantization import CalibrationDataReader
5+
from PIL import Image
6+
7+
8+
def _preprocess_images(images_folder: str, height: int, width: int, size_limit=0):
9+
"""
10+
Loads a batch of images and preprocess them
11+
parameter images_folder: path to folder storing images
12+
parameter height: image height in pixels
13+
parameter width: image width in pixels
14+
parameter size_limit: number of images to load. Default is 0 which means all images are picked.
15+
return: list of matrices characterizing multiple images
16+
"""
17+
image_names = os.listdir(images_folder)
18+
if size_limit > 0 and len(image_names) >= size_limit:
19+
batch_filenames = [image_names[i] for i in range(size_limit)]
20+
else:
21+
batch_filenames = image_names
22+
unconcatenated_batch_data = []
23+
24+
for image_name in batch_filenames:
25+
image_filepath = images_folder + "/" + image_name
26+
pillow_img = Image.new("RGB", (width, height))
27+
pillow_img.paste(Image.open(image_filepath).resize((width, height)))
28+
input_data = numpy.float32(pillow_img) - numpy.array(
29+
[123.68, 116.78, 103.94], dtype=numpy.float32
30+
)
31+
nhwc_data = numpy.expand_dims(input_data, axis=0)
32+
nchw_data = nhwc_data.transpose(0, 3, 1, 2) # ONNX Runtime standard
33+
unconcatenated_batch_data.append(nchw_data)
34+
batch_data = numpy.concatenate(
35+
numpy.expand_dims(unconcatenated_batch_data, axis=0), axis=0
36+
)
37+
return batch_data
38+
39+
40+
class ResNet50DataReader(CalibrationDataReader):
41+
def __init__(self, calibration_image_folder: str, model_path: str):
42+
self.image_folder = calibration_image_folder
43+
self.model_path = model_path
44+
self.preprocess_flag = True
45+
self.enum_data_dicts = []
46+
self.datasize = 0
47+
48+
def get_next(self):
49+
if self.preprocess_flag:
50+
self.preprocess_flag = False
51+
session = onnxruntime.InferenceSession(self.model_path, None)
52+
(_, _, height, width) = session.get_inputs()[0].shape
53+
nhwc_data_list = _preprocess_images(
54+
self.image_folder, height, width, size_limit=0
55+
)
56+
input_name = session.get_inputs()[0].name
57+
self.datasize = len(nhwc_data_list)
58+
self.enum_data_dicts = iter(
59+
[{input_name: nhwc_data} for nhwc_data in nhwc_data_list]
60+
)
61+
return next(self.enum_data_dicts, None)
Lines changed: 29 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,10 @@
1-
import os
2-
import sys
3-
import numpy as np
4-
import re
5-
import abc
6-
import subprocess
7-
import json
81
import argparse
9-
import time
10-
from PIL import Image
11-
12-
import onnx
2+
import numpy as np
133
import onnxruntime
14-
from onnx import helper, TensorProto, numpy_helper
15-
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantFormat, QuantType
16-
17-
18-
class ResNet50DataReader(CalibrationDataReader):
19-
def __init__(self, calibration_image_folder, augmented_model_path='augmented_model.onnx'):
20-
self.image_folder = calibration_image_folder
21-
self.augmented_model_path = augmented_model_path
22-
self.preprocess_flag = True
23-
self.enum_data_dicts = []
24-
self.datasize = 0
25-
26-
def get_next(self):
27-
if self.preprocess_flag:
28-
self.preprocess_flag = False
29-
session = onnxruntime.InferenceSession(self.augmented_model_path, None)
30-
(_, _, height, width) = session.get_inputs()[0].shape
31-
nhwc_data_list = preprocess_func(self.image_folder, height, width, size_limit=0)
32-
input_name = session.get_inputs()[0].name
33-
self.datasize = len(nhwc_data_list)
34-
self.enum_data_dicts = iter([{input_name: nhwc_data} for nhwc_data in nhwc_data_list])
35-
return next(self.enum_data_dicts, None)
36-
37-
38-
def preprocess_func(images_folder, height, width, size_limit=0):
39-
'''
40-
Loads a batch of images and preprocess them
41-
parameter images_folder: path to folder storing images
42-
parameter height: image height in pixels
43-
parameter width: image width in pixels
44-
parameter size_limit: number of images to load. Default is 0 which means all images are picked.
45-
return: list of matrices characterizing multiple images
46-
'''
47-
image_names = os.listdir(images_folder)
48-
if size_limit > 0 and len(image_names) >= size_limit:
49-
batch_filenames = [image_names[i] for i in range(size_limit)]
50-
else:
51-
batch_filenames = image_names
52-
unconcatenated_batch_data = []
4+
import time
5+
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
536

54-
for image_name in batch_filenames:
55-
image_filepath = images_folder + '/' + image_name
56-
pillow_img = Image.new("RGB", (width, height))
57-
pillow_img.paste(Image.open(image_filepath).resize((width, height)))
58-
input_data = np.float32(pillow_img) - \
59-
np.array([123.68, 116.78, 103.94], dtype=np.float32)
60-
nhwc_data = np.expand_dims(input_data, axis=0)
61-
nchw_data = nhwc_data.transpose(0, 3, 1, 2) # ONNX Runtime standard
62-
unconcatenated_batch_data.append(nchw_data)
63-
batch_data = np.concatenate(np.expand_dims(unconcatenated_batch_data, axis=0), axis=0)
64-
return batch_data
7+
import resnet50_data_reader
658

669

6710
def benchmark(model_path):
@@ -87,11 +30,15 @@ def get_args():
8730
parser = argparse.ArgumentParser()
8831
parser.add_argument("--input_model", required=True, help="input model")
8932
parser.add_argument("--output_model", required=True, help="output model")
90-
parser.add_argument("--calibrate_dataset", default="./test_images", help="calibration data set")
91-
parser.add_argument("--quant_format",
92-
default=QuantFormat.QDQ,
93-
type=QuantFormat.from_string,
94-
choices=list(QuantFormat))
33+
parser.add_argument(
34+
"--calibrate_dataset", default="./test_images", help="calibration data set"
35+
)
36+
parser.add_argument(
37+
"--quant_format",
38+
default=QuantFormat.QDQ,
39+
type=QuantFormat.from_string,
40+
choices=list(QuantFormat),
41+
)
9542
parser.add_argument("--per_channel", default=False, type=bool)
9643
args = parser.parse_args()
9744
return args
@@ -102,21 +49,25 @@ def main():
10249
input_model_path = args.input_model
10350
output_model_path = args.output_model
10451
calibration_dataset_path = args.calibrate_dataset
105-
dr = ResNet50DataReader(calibration_dataset_path)
106-
quantize_static(input_model_path,
107-
output_model_path,
108-
dr,
109-
quant_format=args.quant_format,
110-
per_channel=args.per_channel,
111-
weight_type=QuantType.QInt8)
112-
print('Calibrated and quantized model saved.')
113-
114-
print('benchmarking fp32 model...')
52+
dr = resnet50_data_reader.ResNet50DataReader(
53+
calibration_dataset_path, input_model_path
54+
)
55+
quantize_static(
56+
input_model_path,
57+
output_model_path,
58+
dr,
59+
quant_format=args.quant_format,
60+
per_channel=args.per_channel,
61+
weight_type=QuantType.QInt8,
62+
)
63+
print("Calibrated and quantized model saved.")
64+
65+
print("benchmarking fp32 model...")
11566
benchmark(input_model_path)
11667

117-
print('benchmarking int8 model...')
68+
print("benchmarking int8 model...")
11869
benchmark(output_model_path)
11970

12071

121-
if __name__ == '__main__':
72+
if __name__ == "__main__":
12273
main()

0 commit comments

Comments
 (0)