Commit e7a02b5

changed post-quant methods (PaddlePaddle#713)
1 parent 8fad8d4 commit e7a02b5

3 files changed: 24 additions & 10 deletions

demo/quant/quant_post/README.md

Lines changed: 3 additions & 3 deletions
@@ -43,7 +43,7 @@ python quant_post_static.py --model_path ./inference_model/MobileNet --save_path
 
 After running the above command, you can find the quantized model file and parameter file under ``${save_path}``.
 
-> The quantization algorithm used is ``'KL'``, and 160 images from the training set are used to calibrate the quantization parameters.
+> The quantization algorithm used is ``'hist'``, and 32 images from the training set are used to calibrate the quantization parameters.
 
 
 ### Test accuracy
@@ -67,6 +67,6 @@ python eval.py --model_path ./quant_model_train/MobileNet --model_name __model__
 
 The accuracy output is
 ```
-top1_acc/top5_acc= [0.70141864 0.89086477]
+top1_acc/top5_acc= [0.70328485 0.89183184]
 ```
-From the accuracy comparison above, post-training quantization of the ``mobilenet`` classification model on ``imagenet`` gives a ``top1`` accuracy loss of ``0.77%`` and a ``top5`` accuracy loss of ``0.46%``.
+From the accuracy comparison above, post-training quantization of the ``mobilenet`` classification model on ``imagenet`` gives a ``top1`` accuracy loss of ``0.59%`` and a ``top5`` accuracy loss of ``0.36%``.
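For orientation (not part of this commit), the sketch below shows roughly how the README's quantization step maps onto the updated API with the new defaults: algo='hist' and 32 calibration images (batch_size=32, batch_nums=1). The executor, model_dir, quantize_model_path, and sample_generator argument names, as well as the dummy reader, are assumptions about the surrounding paddleslim.quant API and the demo's ImageNet reader, neither of which appears in this diff.

```python
import numpy as np
import paddle
from paddleslim.quant import quant_post_static

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())


def sample_generator():
    # Hypothetical stand-in for the demo's ImageNet reader: yields 32
    # image-shaped samples, enough for one calibration batch.
    for _ in range(32):
        yield (np.random.random((3, 224, 224)).astype('float32'), )


quant_post_static(
    executor=exe,
    model_dir='./inference_model/MobileNet',
    quantize_model_path='./quant_model/MobileNet',
    sample_generator=sample_generator,
    batch_size=32,            # new demo default
    batch_nums=1,             # one batch of 32 images
    algo='hist',              # new default calibration algorithm
    hist_percent=0.9999)      # percentile used by 'hist'
```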

demo/quant/quant_post/quant_post.py

Lines changed: 7 additions & 3 deletions
@@ -19,13 +19,15 @@
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
-add_arg('batch_size', int, 16, "Minibatch size.")
-add_arg('batch_num', int, 10, "Batch number")
+add_arg('batch_size', int, 32, "Minibatch size.")
+add_arg('batch_num', int, 1, "Batch number")
 add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
 add_arg('model_path', str, "./inference_model/MobileNet/", "model dir")
 add_arg('save_path', str, "./quant_model/MobileNet/", "model dir to save quanted model")
 add_arg('model_filename', str, None, "model file name")
 add_arg('params_filename', str, None, "params file name")
+add_arg('algo', str, 'hist', "calibration algorithm")
+add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist")
 # yapf: enable
 
 

@@ -46,7 +48,9 @@ def quantize(args):
         model_filename=args.model_filename,
         params_filename=args.params_filename,
         batch_size=args.batch_size,
-        batch_nums=args.batch_num)
+        batch_nums=args.batch_num,
+        algo=args.algo,
+        hist_percent=args.hist_percent)
 
 
 def main():
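The hunks above only show the keyword arguments that changed. As a condensed, hypothetical view of the wiring, the sketch below uses plain argparse as a stand-in for the demo's add_arg helper (defined in its utility module, not shown in this diff) and collects the new calibration flags into the keyword arguments forwarded to the quantization call.

```python
import argparse

# Stand-in for the demo's add_arg/add_arguments helpers; defaults mirror
# the first hunk above.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--batch_size', type=int, default=32, help="Minibatch size.")
parser.add_argument('--batch_num', type=int, default=1, help="Batch number")
parser.add_argument('--algo', type=str, default='hist', help="calibration algorithm")
parser.add_argument('--hist_percent', type=float, default=0.9999,
                    help="The percentile of algo:hist")
args = parser.parse_args([])  # empty list: use the defaults

# Keyword arguments forwarded to the post-training quantization call,
# matching the second hunk above (the callee's remaining arguments are
# not shown in this diff).
calib_kwargs = dict(
    batch_size=args.batch_size,
    batch_nums=args.batch_num,
    algo=args.algo,
    hist_percent=args.hist_percent)
print(calib_kwargs)
```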

paddleslim/quant/quanter.py

Lines changed: 14 additions & 4 deletions
@@ -313,7 +313,9 @@ def quant_post_static(
        batch_size=16,
        batch_nums=None,
        scope=None,
-        algo='KL',
+        algo='hist',
+        hist_percent=0.9999,
+        bias_correction=False,
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        is_full_quantize=False,
        weight_bits=8,
@@ -358,9 +360,15 @@ def quant_post_static(
             generated by sample_generator as calibrate data.
         scope(paddle.static.Scope, optional): The scope to run program, use it to load
             and save variables. If scope is None, will use paddle.static.global_scope().
-        algo(str, optional): If algo=KL, use KL-divergenc method to
-            get the more precise scale factor. If algo='direct', use
-            abs_max method to get the scale factor. Default: 'KL'.
+        algo(str, optional): If algo='KL', use the KL-divergence method to
+            get the scale factor. If algo='hist', use the hist_percent percentile of the
+            histogram to get the scale factor. If algo='mse', search for the scale factor
+            that minimizes the mse loss; one batch of data is enough for mse. If
+            algo='avg', use the average of abs_max values to get the scale factor. If
+            algo='abs_max', use the abs_max method to get the scale factor. Default: 'hist'.
+        hist_percent(float, optional): The percentile of the histogram for algo='hist'. Default: 0.9999.
+        bias_correction(bool, optional): Whether to use the bias correction method of
+            https://arxiv.org/abs/1810.05723. Default: False.
         quantizable_op_type(list[str], optional): The list of op types
             that will be quantized. Default: ["conv2d", "depthwise_conv2d",
             "mul"].
@@ -397,6 +405,8 @@ def quant_post_static(
        batch_nums=batch_nums,
        scope=scope,
        algo=algo,
+        hist_percent=hist_percent,
+        bias_correction=bias_correction,
        quantizable_op_type=quantizable_op_type,
        is_full_quantize=is_full_quantize,
        weight_bits=weight_bits,
