Skip to content

Commit 0548aac

Browse files
authored
Merge pull request #15532 from hshen14/calibration_api_refine
Refine INT8 calibration API
2 parents 8e2dea5 + 2a82c56 commit 0548aac

File tree

2 files changed

+42
-27
lines changed

2 files changed

+42
-27
lines changed

python/paddle/fluid/contrib/int8_inference/utility.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ class Calibrator(object):
3232

3333
def __init__(self, *args, **kwargs):
3434
self.program = kwargs['program']
35-
self.iterations = kwargs['iterations']
3635
self.pretrained_model = kwargs['pretrained_model']
37-
self.debug = kwargs['debug']
36+
self.debug = kwargs['debug'] if 'debug' in kwargs else False
3837
self.algo = kwargs['algo']
38+
self.output = kwargs['output']
39+
self.feed_var_names = kwargs['feed_var_names']
40+
self.fetch_list = kwargs['fetch_list']
41+
self.exe = kwargs['exe']
3942

4043
self._conv_input_var_name = []
4144
self._conv_output_var_name = []
@@ -54,17 +57,38 @@ def __init__(self, *args, **kwargs):
5457
self._u8_output_var = []
5558
self._s8_output_var = []
5659
self._persistable_vars = []
60+
self._sampling_data = {}
5761

58-
def generate_sampling_program(self):
5962
self.__init_analysis()
6063
self.__generate_output_program()
6164

62-
def generate_quantized_data(self, sampling_data):
63-
self.__sampling(sampling_data)
65+
def save_int8_model(self):
66+
self.__sampling(self._sampling_data)
6467
self.__save_scale()
6568
self.__update_program()
6669
self.__update_output_program_attr()
6770
self.__display_debug()
71+
self.__save_offline_model()
72+
73+
def sample_data(self):
74+
'''
75+
Sampling the tensor data of variable.
76+
'''
77+
for i in self.sampling_program.list_vars():
78+
if i.name in self.sampling_vars:
79+
np_data = np.array(fluid.global_scope().find_var(i.name)
80+
.get_tensor())
81+
if i.name not in self._sampling_data:
82+
self._sampling_data[i.name] = []
83+
self._sampling_data[i.name].append(np_data)
84+
85+
def __save_offline_model(self):
86+
'''
87+
Save the quantized model to the disk.
88+
'''
89+
fluid.io.save_inference_model(self.output, self.feed_var_names,
90+
self.fetch_list, self.exe,
91+
self.sampling_program)
6892

6993
def __display_debug(self):
7094
if self.debug:

python/paddle/fluid/contrib/tests/test_calibration.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from PIL import Image, ImageEnhance
2727
import math
2828
sys.path.append('..')
29-
import int8_inference.utility as ut
29+
import int8_inference.utility as int8_utility
3030

3131
random.seed(0)
3232
np.random.seed(0)
@@ -120,13 +120,13 @@ class TestCalibration(unittest.TestCase):
120120
def setUp(self):
121121
# TODO(guomingz): Put the download process in the cmake.
122122
# Download and unzip test data set
123-
imagenet_dl_url = 'http://paddle-inference-dist.bj.bcebos.com/int8/calibration_test_data.tar.gz'
123+
imagenet_dl_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/calibration_test_data.tar.gz'
124124
zip_file_name = imagenet_dl_url.split('/')[-1]
125125
cmd = 'rm -rf data {} && mkdir data && wget {} && tar xvf {} -C data'.format(
126126
zip_file_name, imagenet_dl_url, zip_file_name)
127127
os.system(cmd)
128128
# resnet50 fp32 data
129-
resnet50_fp32_model_url = 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
129+
resnet50_fp32_model_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/resnet50_int8_model.tar.gz'
130130
resnet50_zip_name = resnet50_fp32_model_url.split('/')[-1]
131131
resnet50_unzip_folder_name = 'resnet50_fp32'
132132
cmd = 'rm -rf {} {} && mkdir {} && wget {} && tar xvf {} -C {}'.format(
@@ -135,8 +135,7 @@ def setUp(self):
135135
resnet50_zip_name, resnet50_unzip_folder_name)
136136
os.system(cmd)
137137

138-
self.iterations = 100
139-
self.skip_batch_num = 5
138+
self.iterations = 50
140139

141140
def run_program(self, model_path, generate_int8=False, algo='direct'):
142141
image_shape = [3, 224, 224]
@@ -163,16 +162,15 @@ def run_program(self, model_path, generate_int8=False, algo='direct'):
163162

164163
print("Start calibration ...")
165164

166-
calibrator = ut.Calibrator(
165+
calibrator = int8_utility.Calibrator(
167166
program=infer_program,
168167
pretrained_model=model_path,
169-
iterations=100,
170-
debug=False,
171-
algo=algo)
172-
173-
sampling_data = {}
168+
algo=algo,
169+
exe=exe,
170+
output=int8_model,
171+
feed_var_names=feed_dict,
172+
fetch_list=fetch_targets)
174173

175-
calibrator.generate_sampling_program()
176174
test_info = []
177175
cnt = 0
178176
for batch_id, data in enumerate(val_reader()):
@@ -192,13 +190,7 @@ def run_program(self, model_path, generate_int8=False, algo='direct'):
192190
feed_dict[1]: label},
193191
fetch_list=fetch_targets)
194192
if generate_int8:
195-
for i in calibrator.sampling_program.list_vars():
196-
if i.name in calibrator.sampling_vars:
197-
np_data = np.array(fluid.global_scope().find_var(i.name)
198-
.get_tensor())
199-
if i.name not in sampling_data:
200-
sampling_data[i.name] = []
201-
sampling_data[i.name].append(np_data)
193+
calibrator.sample_data()
202194

203195
test_info.append(np.mean(acc1) * len(data))
204196
cnt += len(data)
@@ -209,9 +201,8 @@ def run_program(self, model_path, generate_int8=False, algo='direct'):
209201
break
210202

211203
if generate_int8:
212-
calibrator.generate_quantized_data(sampling_data)
213-
fluid.io.save_inference_model(int8_model, feed_dict, fetch_targets,
214-
exe, calibrator.sampling_program)
204+
calibrator.save_int8_model()
205+
215206
print(
216207
"Calibration is done and the corresponding files were generated at {}".
217208
format(os.path.abspath("calibration_out")))

0 commit comments

Comments
 (0)