26
26
from PIL import Image , ImageEnhance
27
27
import math
28
28
sys .path .append ('..' )
29
- import int8_inference .utility as ut
29
+ import int8_inference .utility as int8_utility
30
30
31
31
random .seed (0 )
32
32
np .random .seed (0 )
@@ -120,13 +120,13 @@ class TestCalibration(unittest.TestCase):
120
120
def setUp (self ):
121
121
# TODO(guomingz): Put the download process in the cmake.
122
122
# Download and unzip test data set
123
- imagenet_dl_url = 'http://paddle-inference-dist.bj .bcebos.com/int8/calibration_test_data.tar.gz'
123
+ imagenet_dl_url = 'http://paddle-inference-dist.cdn .bcebos.com/int8/calibration_test_data.tar.gz'
124
124
zip_file_name = imagenet_dl_url .split ('/' )[- 1 ]
125
125
cmd = 'rm -rf data {} && mkdir data && wget {} && tar xvf {} -C data' .format (
126
126
zip_file_name , imagenet_dl_url , zip_file_name )
127
127
os .system (cmd )
128
128
# resnet50 fp32 data
129
- resnet50_fp32_model_url = 'http://paddle-inference-dist.bj .bcebos.com/int8/resnet50_int8_model.tar.gz'
129
+ resnet50_fp32_model_url = 'http://paddle-inference-dist.cdn .bcebos.com/int8/resnet50_int8_model.tar.gz'
130
130
resnet50_zip_name = resnet50_fp32_model_url .split ('/' )[- 1 ]
131
131
resnet50_unzip_folder_name = 'resnet50_fp32'
132
132
cmd = 'rm -rf {} {} && mkdir {} && wget {} && tar xvf {} -C {}' .format (
@@ -135,8 +135,7 @@ def setUp(self):
135
135
resnet50_zip_name , resnet50_unzip_folder_name )
136
136
os .system (cmd )
137
137
138
- self .iterations = 100
139
- self .skip_batch_num = 5
138
+ self .iterations = 50
140
139
141
140
def run_program (self , model_path , generate_int8 = False , algo = 'direct' ):
142
141
image_shape = [3 , 224 , 224 ]
@@ -163,16 +162,15 @@ def run_program(self, model_path, generate_int8=False, algo='direct'):
163
162
164
163
print ("Start calibration ..." )
165
164
166
- calibrator = ut .Calibrator (
165
+ calibrator = int8_utility .Calibrator (
167
166
program = infer_program ,
168
167
pretrained_model = model_path ,
169
- iterations = 100 ,
170
- debug = False ,
171
- algo = algo )
172
-
173
- sampling_data = {}
168
+ algo = algo ,
169
+ exe = exe ,
170
+ output = int8_model ,
171
+ feed_var_names = feed_dict ,
172
+ fetch_list = fetch_targets )
174
173
175
- calibrator .generate_sampling_program ()
176
174
test_info = []
177
175
cnt = 0
178
176
for batch_id , data in enumerate (val_reader ()):
@@ -192,13 +190,7 @@ def run_program(self, model_path, generate_int8=False, algo='direct'):
192
190
feed_dict [1 ]: label },
193
191
fetch_list = fetch_targets )
194
192
if generate_int8 :
195
- for i in calibrator .sampling_program .list_vars ():
196
- if i .name in calibrator .sampling_vars :
197
- np_data = np .array (fluid .global_scope ().find_var (i .name )
198
- .get_tensor ())
199
- if i .name not in sampling_data :
200
- sampling_data [i .name ] = []
201
- sampling_data [i .name ].append (np_data )
193
+ calibrator .sample_data ()
202
194
203
195
test_info .append (np .mean (acc1 ) * len (data ))
204
196
cnt += len (data )
@@ -209,9 +201,8 @@ def run_program(self, model_path, generate_int8=False, algo='direct'):
209
201
break
210
202
211
203
if generate_int8 :
212
- calibrator .generate_quantized_data (sampling_data )
213
- fluid .io .save_inference_model (int8_model , feed_dict , fetch_targets ,
214
- exe , calibrator .sampling_program )
204
+ calibrator .save_int8_model ()
205
+
215
206
print (
216
207
"Calibration is done and the corresponding files were generated at {}" .
217
208
format (os .path .abspath ("calibration_out" )))
0 commit comments