Skip to content

Commit d31a202

Browse files
authored
add adaround post quant method (PaddlePaddle#1023)
1 parent d2ecc86 commit d31a202

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

demo/quant/quant_post/quant_post.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
add_arg('model_filename', str, None, "model file name")
3030
add_arg('params_filename', str, None, "params file name")
3131
add_arg('algo', str, 'hist', "calibration algorithm")
32+
add_arg('round_type', str, 'round', "The method of converting the quantized weights.")
3233
add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist")
3334
add_arg('bias_correction', bool, False, "Whether to use bias correction")
3435
add_arg('ce_test', bool, False, "Whether to CE test.")
@@ -74,6 +75,7 @@ def quantize(args):
7475
batch_size=args.batch_size,
7576
batch_nums=args.batch_num,
7677
algo=args.algo,
78+
round_type=args.round_type,
7779
hist_percent=args.hist_percent,
7880
bias_correction=args.bias_correction)
7981

paddleslim/quant/quanter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ def quant_post_static(
325325
batch_nums=None,
326326
scope=None,
327327
algo='hist',
328+
round_type='round',
328329
hist_percent=0.9999,
329330
bias_correction=False,
330331
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
@@ -380,6 +381,9 @@ def quant_post_static(
380381
makes the mse loss minimal. Use one batch of data for mse is enough. If
381382
algo='avg', use the average of abs_max values to get the scale factor. If
382383
algo='abs_max', use abs_max method to get the scale factor. Default: 'hist'.
384+
round_type(str, optional): The method of converting the quantized weights value
385+
from float to int. Currently supports ['round', 'adaround'] methods.
386+
Default is `round`, which rounds each value to the nearest whole number.
383387
hist_percent(float, optional): The percentile of histogram for algo hist. Default: 0.9999.
384388
bias_correction(bool, optional): Bias correction method of https://arxiv.org/abs/1810.05723.
385389
Default: False.
@@ -420,6 +424,7 @@ def quant_post_static(
420424
batch_nums=batch_nums,
421425
scope=scope,
422426
algo=algo,
427+
round_type=round_type,
423428
hist_percent=hist_percent,
424429
bias_correction=bias_correction,
425430
quantizable_op_type=quantizable_op_type,

0 commit comments

Comments
 (0)