
Commit 926227c

Post_training_quantization support set quant 8/16 bits (#22492) (#22577)
1 parent baec7a3 commit 926227c

File tree

2 files changed: +37 -28 lines changed


python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py

Lines changed: 29 additions & 20 deletions
@@ -54,17 +54,19 @@ def _set_variable_data(scope, place, var_name, np_value):
 
 class PostTrainingQuantization(object):
     def __init__(self,
-                 executor,
-                 sample_generator,
-                 model_dir,
+                 executor=None,
+                 scope=None,
+                 model_dir=None,
                  model_filename=None,
                  params_filename=None,
+                 sample_generator=None,
                  batch_size=10,
                  batch_nums=None,
-                 scope=None,
                  algo="KL",
                  quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                  is_full_quantize=False,
+                 weight_bits=8,
+                 activation_bits=8,
                  is_use_cache_file=False,
                  cache_dir="./temp_post_training"):
         '''
@@ -76,9 +78,8 @@ def __init__(self,
         Args:
             executor(fluid.Executor): The executor to load, run and save the
                 quantized model.
-            sample_generator(Python Generator): The sample generator provides
-                calibrate data for DataLoader, and it only returns a sample every
-                time.
+            scope(fluid.Scope, optional): The scope of the program, use it to load
+                and save variables. If scope=None, get scope by global_scope().
             model_dir(str): The path of the fp32 model that will be quantized,
                 and the model and params files are under the path.
             model_filename(str, optional): The name of file to load the inference
@@ -88,12 +89,13 @@ def __init__(self,
                 When all parameters were saved in a single binary file, set it
                 as the real filename. If parameters were saved in separate files,
                 set it as 'None'. Default is 'None'.
+            sample_generator(Python Generator): The sample generator provides
+                calibrate data for DataLoader, and it only returns a sample every
+                time.
             batch_size(int, optional): The batch size of DataLoader. Default is 10.
             batch_nums(int, optional): If batch_nums is not None, the number of
                 calibrate data is batch_size*batch_nums. If batch_nums is None, use
                 all data provided by sample_generator as calibrate data.
-            scope(fluid.Scope, optional): The scope of the program, use it to load
-                and save variables. If scope=None, get scope by global_scope().
             algo(str, optional): If algo=KL, use KL-divergenc method to
                 get the more precise scale factor. If algo='direct', use
                 abs_max methon to get the scale factor. Default is KL.
@@ -104,6 +106,8 @@ def __init__(self,
                 apply quantization to all supported quantizable op type. If set
                 is_full_quantized as False, only apply quantization to the op type
                 according to the input quantizable_op_type.
+            weight_bits(int, optional): quantization bit number for weights.
+            activation_bits(int): quantization bit number for activation.
             is_use_cache_file(bool, optional): If set is_use_cache_file as False,
                 all temp data will be saved in memory. If set is_use_cache_file as True,
                 it will save temp data to disk. When the fp32 model is complex or
@@ -150,14 +154,20 @@ def __init__(self,
             ptq.quantize()
             ptq.save_quantized_model(save_model_path)
         '''
+
+        assert executor is not None, "The executor cannot be None."
+        assert model_dir is not None, "The model_dir cannot be None."
+        assert sample_generator is not None, \
+            "The sample_generator cannot be None."
+
         self._executor = executor
-        self._sample_generator = sample_generator
+        self._scope = global_scope() if scope == None else scope
         self._model_dir = model_dir
         self._model_filename = model_filename
         self._params_filename = params_filename
+        self._sample_generator = sample_generator
         self._batch_size = batch_size
         self._batch_nums = batch_nums
-        self._scope = global_scope() if scope == None else scope
         self._algo = algo
         self._is_use_cache_file = is_use_cache_file
         self._cache_dir = cache_dir
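With this change, executor, scope, model_dir, and sample_generator become keyword arguments guarded by asserts, and the new weight_bits/activation_bits arguments select the quantization bit number. A minimal sketch of the updated constructor, assuming a hypothetical model path and calibration generator (only weight_bits/activation_bits are new; the rest mirrors the docstring example above):

import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

def sample_generator():
    # Hypothetical calibration generator: yields one sample at a time
    # for the DataLoader, as the docstring describes.
    for sample in calibration_samples:  # assumed to be defined elsewhere
        yield sample

exe = fluid.Executor(fluid.CPUPlace())
ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./fp32_model",            # hypothetical path
    sample_generator=sample_generator,
    batch_size=10,
    algo="KL",
    weight_bits=8,                       # new: 8 or 16
    activation_bits=8)                   # new: 8 or 16
ptq.quantize()
ptq.save_quantized_model("./quant_model")  # hypothetical path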
@@ -604,7 +614,7 @@ def quantize_weight_to_int(self,
                               save_model_filename=None,
                               save_params_filename=None,
                               quantizable_op_type=["conv2d", "mul"],
-                              quantize_weight_bits=8,
+                              weight_bits=8,
                               threshold_rate=0.0):
         '''
         In order to reduce the size of model, this api quantizes the weight
@@ -624,8 +634,8 @@ def quantize_weight_to_int(self,
                 that will be quantized, and the quantized ops should be
                 contained in ["conv2d", "depthwise_conv2d", "mul"].
                 Default is ["conv2d","mul"].
-            quantize_weight_bits(int, optional): The bits for the quantized
-                weight, and it should be 8 or 16. Default is 8.
+            weight_bits(int, optional): The bits for the quantized weight,
+                and it should be 8 or 16. Default is 8.
             threshold_rate(float, optional): This api uses abs_max methd to
                 quantize the weight from float32 to int8/16, and the abs max
                 value is important for quantization diff. When the abs_max
@@ -637,10 +647,10 @@ def quantize_weight_to_int(self,
             assert op_type in self._supported_quantizable_op_type, \
                 "input error:" + op_type + \
                 " is not supported for weight quantization."
-        assert quantize_weight_bits in [8, 16], \
-            "input error: quantize_weight_bits should be 8 or 16."
-        quantize_range = (1 << (quantize_weight_bits - 1)) - 1
-        save_weight_dtype = np.int8 if quantize_weight_bits == 8 else np.int16
+        assert weight_bits in [8, 16], \
+            "input error: weight_bits should be 8 or 16."
+        quantize_range = (1 << (weight_bits - 1)) - 1
+        save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
 
         place = core.CPUPlace()
         exe = Executor(place)
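As a quick sanity check of the renamed logic, the symmetric range and storage dtype for the two allowed bit widths (a standalone sketch, independent of Paddle):

import numpy as np

for weight_bits in (8, 16):
    quantize_range = (1 << (weight_bits - 1)) - 1
    save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
    # weight_bits=8  -> quantize_range=127,   dtype=int8
    # weight_bits=16 -> quantize_range=32767, dtype=int16
    print(weight_bits, quantize_range, save_weight_dtype)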
@@ -677,8 +687,7 @@ def quantize_weight_to_int(self,
                     _set_variable_data(scope, place, var_name,
                                        quantized_var_tensor_data)
                     op._set_attr(var_name + "_quant_scale", [scale])
-                    op._set_attr('quantize_weight_bits',
-                                 quantize_weight_bits)
+                    op._set_attr('quantize_weight_bits', weight_bits)
 
         io.save_inference_model(
             dirname=save_model_dir,
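Callers of quantize_weight_to_int now pass weight_bits in place of quantize_weight_bits; a minimal sketch of the updated call with hypothetical paths, using the same import style as the updated test below:

from paddle.fluid.contrib.slim.quantization import WeightQuantization

weight_quant = WeightQuantization(model_dir="./fp32_model/model")  # hypothetical path
weight_quant.quantize_weight_to_int(
    save_model_dir="./weight_quant_model",   # hypothetical path
    weight_bits=8,                           # renamed from quantize_weight_bits
    quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
    threshold_rate=0.0)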

python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py

Lines changed: 8 additions & 8 deletions
@@ -42,20 +42,20 @@ def cache_unzipping(self, target_folder, zip_path):
                    zip_path)
         os.system(cmd)
 
-    def run_test(self, model_name, model_data_url, model_data_md5,
-                 quantize_weight_bits, quantizable_op_type, threshold_rate):
+    def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
+                 quantizable_op_type, threshold_rate):
 
         model_dir = self.download_model(model_name, model_data_url,
                                         model_data_md5)
 
         timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
         save_model_dir = os.path.join(
             os.getcwd(),
-            model_name + "_wq_" + str(quantize_weight_bits) + "_" + timestamp)
+            model_name + "_wq_" + str(weight_bits) + "_" + timestamp)
         weight_quant = WeightQuantization(model_dir=model_dir + "/model")
         weight_quant.quantize_weight_to_int(
             save_model_dir=save_model_dir,
-            quantize_weight_bits=quantize_weight_bits,
+            weight_bits=weight_bits,
             quantizable_op_type=quantizable_op_type,
             threshold_rate=threshold_rate)
         print("finish weight quantization for " + model_name + "\n")
@@ -73,18 +73,18 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
     model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
 
     def test_weight_quantization_mobilenetv1_8bit(self):
-        quantize_weight_bits = 8
+        weight_bits = 8
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
         threshold_rate = 0.0
         self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      quantize_weight_bits, quantizable_op_type, threshold_rate)
+                      weight_bits, quantizable_op_type, threshold_rate)
 
     def test_weight_quantization_mobilenetv1_16bit(self):
-        quantize_weight_bits = 16
+        weight_bits = 16
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
         threshold_rate = 1e-9
         self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      quantize_weight_bits, quantizable_op_type, threshold_rate)
+                      weight_bits, quantizable_op_type, threshold_rate)
 
 
 if __name__ == '__main__':
