1- import os
2- import sys
3- import numpy as np
4- import re
5- import abc
6- import subprocess
7- import json
81import argparse
9- import time
10- from PIL import Image
11-
12- import onnx
2+ import numpy as np
133import onnxruntime
14- from onnx import helper , TensorProto , numpy_helper
15- from onnxruntime .quantization import quantize_static , CalibrationDataReader , QuantFormat , QuantType
16-
17-
class ResNet50DataReader(CalibrationDataReader):
    '''
    Calibration data reader that hands preprocessed images to the quantizer,
    one feed dictionary per call to get_next.
    Nothing is loaded until get_next is called for the first time.
    '''

    def __init__(self, calibration_image_folder, augmented_model_path='augmented_model.onnx'):
        '''
        Stores the locations needed later; performs no I/O.
        parameter calibration_image_folder: path to folder storing calibration images
        parameter augmented_model_path: path of the ONNX model whose first input
            supplies the image height/width and the feed name
        '''
        self.image_folder = calibration_image_folder
        self.augmented_model_path = augmented_model_path
        # True until get_next has run its one-time loading step.
        self.preprocess_flag = True
        self.enum_data_dicts = []
        self.datasize = 0

    def get_next(self):
        '''
        Returns the next calibration feed.
        return: dict mapping the model's first input name to one preprocessed
            image, or None once every image has been handed out
        '''
        if self.preprocess_flag:
            # Clear the flag first so the loading step is attempted only once.
            self.preprocess_flag = False
            session = onnxruntime.InferenceSession(self.augmented_model_path, None)
            # The first model input is expected to have a 4-element shape whose
            # last two entries are the image height and width.
            _, _, height, width = session.get_inputs()[0].shape
            image_batch = preprocess_func(self.image_folder, height, width, size_limit=0)
            feed_name = session.get_inputs()[0].name
            self.datasize = len(image_batch)
            feeds = [{feed_name: image} for image in image_batch]
            self.enum_data_dicts = iter(feeds)
        return next(self.enum_data_dicts, None)
36-
37-
def preprocess_func(images_folder, height, width, size_limit=0):
    '''
    Loads a batch of images and preprocesses them
    parameter images_folder: path to folder storing images
    parameter height: image height in pixels
    parameter width: image width in pixels
    parameter size_limit: number of images to load. Default is 0 which means all images are picked.
    return: float32 array stacking one (1, 3, height, width) entry per image;
        an empty array when the folder yields no images
    '''
    # Sort the listing: os.listdir returns entries in arbitrary order, which
    # would make the subset chosen by size_limit differ between runs/machines.
    image_names = sorted(os.listdir(images_folder))
    # Slicing past the end simply returns everything, so no length check is needed.
    batch_filenames = image_names[:size_limit] if size_limit > 0 else image_names

    # Per-channel offsets subtracted from the RGB pixels (these look like the
    # usual ImageNet channel means). Built once instead of once per image.
    channel_offsets = np.array([123.68, 116.78, 103.94], dtype=np.float32)

    per_image_data = []
    for image_name in batch_filenames:
        image_filepath = os.path.join(images_folder, image_name)
        # Pasting onto a fresh RGB canvas normalizes the mode, so the array
        # always has three channels.
        canvas = Image.new("RGB", (width, height))
        # The context manager closes the file handle promptly; the original
        # left every opened image file to the garbage collector.
        with Image.open(image_filepath) as source_img:
            canvas.paste(source_img.resize((width, height)))
        hwc_data = np.float32(canvas) - channel_offsets
        nhwc_data = np.expand_dims(hwc_data, axis=0)
        nchw_data = nhwc_data.transpose(0, 3, 1, 2)  # ONNX Runtime standard
        per_image_data.append(nchw_data)

    if not per_image_data:
        # Keep the historical "empty result" contract (len() == 0, iterates
        # to nothing) instead of letting np.stack raise on an empty list.
        return np.empty((0,), dtype=np.float32)
    # Same result as the former concatenate(expand_dims(...)) construction:
    # a new leading axis indexes the images.
    return np.stack(per_image_data, axis=0)
7+ import resnet50_data_reader
658
669
6710def benchmark (model_path ):
def get_args():
    '''
    Parses the command-line options of the quantization example
    return: argparse.Namespace with input_model, output_model,
        calibrate_dataset, quant_format and per_channel
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_model", required=True, help="input model")
    parser.add_argument("--output_model", required=True, help="output model")
    parser.add_argument(
        "--calibrate_dataset", default="./test_images", help="calibration data set"
    )
    parser.add_argument(
        "--quant_format",
        default=QuantFormat.QDQ,
        type=QuantFormat.from_string,
        choices=list(QuantFormat),
    )
    # type=bool would call bool() on the raw string, so "--per_channel False"
    # (any non-empty text) came out as True. Parse the text explicitly instead.
    parser.add_argument("--per_channel", default=False, type=_parse_bool_option)
    args = parser.parse_args()
    return args


def _parse_bool_option(text):
    '''
    Converts a command-line string into a bool
    parameter text: raw option value such as "true", "False", "1" or "0"
    return: the corresponding bool; an empty string maps to False
    raises: argparse.ArgumentTypeError for unrecognized values
    '''
    normalized = text.strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % (text,)
    )
@@ -102,21 +49,25 @@ def main():
10249 input_model_path = args .input_model
10350 output_model_path = args .output_model
10451 calibration_dataset_path = args .calibrate_dataset
105- dr = ResNet50DataReader (calibration_dataset_path )
106- quantize_static (input_model_path ,
107- output_model_path ,
108- dr ,
109- quant_format = args .quant_format ,
110- per_channel = args .per_channel ,
111- weight_type = QuantType .QInt8 )
112- print ('Calibrated and quantized model saved.' )
113-
114- print ('benchmarking fp32 model...' )
52+ dr = resnet50_data_reader .ResNet50DataReader (
53+ calibration_dataset_path , input_model_path
54+ )
55+ quantize_static (
56+ input_model_path ,
57+ output_model_path ,
58+ dr ,
59+ quant_format = args .quant_format ,
60+ per_channel = args .per_channel ,
61+ weight_type = QuantType .QInt8 ,
62+ )
63+ print ("Calibrated and quantized model saved." )
64+
65+ print ("benchmarking fp32 model..." )
11566 benchmark (input_model_path )
11667
117- print (' benchmarking int8 model...' )
68+ print (" benchmarking int8 model..." )
11869 benchmark (output_model_path )
11970
12071
# Run the example only when this file is executed as a script. The comparison
# string must be exactly "__main__" (no surrounding whitespace), otherwise the
# guard is never true and nothing runs.
if __name__ == "__main__":
    main()
0 commit comments