
Commit d2ecc86

fix ptq hpo (PaddlePaddle#1021)
1 parent a00f830 commit d2ecc86

File tree

2 files changed: +61 −47 lines

  demo/quant/quant_post_hpo/quant_post_hpo.py
  paddleslim/quant/quant_post_hpo.py

demo/quant/quant_post_hpo/quant_post_hpo.py

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ def gen():
         save_model_filename='__model__',
         save_params_filename='__params__',
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
-        weight_quantize_type='channel_wise_abs_max',
+        weight_quantize_type=['channel_wise_abs_max'],
         runcount_limit=args.max_model_quant_count)

 def main():
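
The fix here is one type change: weight_quantize_type must be a list of candidate types, matching the candidate-list convention visible in the quant_post_hpo signature further down. A minimal sketch of the corrected demo call, assuming gen is the calibration-data generator the demo defines and the model paths are placeholders:

# Sketch of the corrected call; paths are hypothetical and `gen` stands
# for the calibration-data generator defined earlier in the demo.
import paddle
from paddleslim.quant import quant_post_hpo

paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)

quant_post_hpo(
    executor=exe,
    place=place,
    model_dir='./fp32_model',                 # hypothetical path
    quantize_model_path='./quantized_model',  # hypothetical path
    train_sample_generator=gen,
    eval_sample_generator=gen,
    save_model_filename='__model__',
    save_params_filename='__params__',
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    weight_quantize_type=['channel_wise_abs_max'],  # list, not a bare string
    runcount_limit=30)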

paddleslim/quant/quant_post_hpo.py

Lines changed: 60 additions & 46 deletions

@@ -189,10 +189,10 @@ def eval_quant_model():
     valid_data_num = 0
     max_eval_data_num = 200
     if g_quant_config.eval_sample_generator is not None:
-        feed_dict=False
+        feed_dict = False
         eval_dataloader = g_quant_config.eval_sample_generator
     else:
-        feed_dict=True
+        feed_dict = True
         eval_dataloader = g_quant_config.eval_dataloader
     for i, data in enumerate(eval_dataloader()):
         with paddle.static.scope_guard(float_scope):
@@ -236,12 +236,20 @@ def eval_quant_model():
 def quantize(cfg):
     """model quantize job"""
     algo = cfg["algo"] if 'algo' in cfg else g_quant_config.algo[0][0]
-    hist_percent = cfg["hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[0][0]
-    bias_correct = cfg["bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[0][0]
-    batch_size = cfg["batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][0]
-    batch_num = cfg["batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
-    weight_quantize_type = cfg["weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[0]
-    print(hist_percent, bias_correct, batch_size, batch_num, weight_quantize_type)
+    hist_percent = cfg[
+        "hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[
+            0][0]
+    bias_correct = cfg[
+        "bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[
+            0][0]
+    batch_size = cfg[
+        "batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][
+            0]
+    batch_num = cfg[
+        "batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
+    weight_quantize_type = cfg[
+        "weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[
+            0]

     quant_post( \
         executor=g_quant_config.executor, \
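
The rewrapped ternaries above are formatter-style line breaking (plus removal of the debug print); the logic is unchanged. For readers puzzling through the wrapped form, an equivalent spelling of the fallback pattern, as a sketch rather than part of the commit:

# Sketch only: each setting comes from the sampled configuration when the
# tuner chose it, otherwise from the first candidate in the search space.
def resolve(cfg, key, default):
    return cfg[key] if key in cfg else default

# e.g., with the names used in quantize(cfg) above:
# hist_percent = resolve(cfg, "hist_percent", g_quant_config.hist_percent[0][0])
# batch_size = resolve(cfg, "batch_size", g_quant_config.batch_size[0][0])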
@@ -279,34 +287,35 @@ def quantize(cfg):
     return emd_loss


-def quant_post_hpo(executor,
-                   place,
-                   model_dir,
-                   quantize_model_path,
-                   train_sample_generator=None,
-                   eval_sample_generator=None,
-                   train_dataloader=None,
-                   eval_dataloader=None,
-                   eval_function=None,
-                   model_filename=None,
-                   params_filename=None,
-                   save_model_filename='__model__',
-                   save_params_filename='__params__',
-                   scope=None,
-                   quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
-                   is_full_quantize=False,
-                   weight_bits=8,
-                   activation_bits=8,
-                   weight_quantize_type=['channel_wise_abs_max'],
-                   algo=["KL", "hist", "avg", "mse"],
-                   bias_correct=[True, False],
-                   hist_percent=[0.98, 0.999], ### uniform sample in list.
-                   batch_size=[10, 30], ### uniform sample in list.
-                   batch_num=[10, 30], ### uniform sample in list.
-                   optimize_model=False,
-                   is_use_cache_file=False,
-                   cache_dir="./temp_post_training",
-                   runcount_limit=30):
+def quant_post_hpo(
+        executor,
+        place,
+        model_dir,
+        quantize_model_path,
+        train_sample_generator=None,
+        eval_sample_generator=None,
+        train_dataloader=None,
+        eval_dataloader=None,
+        eval_function=None,
+        model_filename=None,
+        params_filename=None,
+        save_model_filename='__model__',
+        save_params_filename='__params__',
+        scope=None,
+        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+        is_full_quantize=False,
+        weight_bits=8,
+        activation_bits=8,
+        weight_quantize_type=['channel_wise_abs_max'],
+        algo=["KL", "hist", "avg", "mse"],
+        bias_correct=[True, False],
+        hist_percent=[0.98, 0.999],  ### uniform sample in list.
+        batch_size=[10, 30],  ### uniform sample in list.
+        batch_num=[10, 30],  ### uniform sample in list.
+        optimize_model=False,
+        is_use_cache_file=False,
+        cache_dir="./temp_post_training",
+        runcount_limit=30):
     """
     The function utilizes static post training quantization method to
     quantize the fp32 model. It uses calibrate data to calculate the
@@ -360,25 +369,27 @@ def quant_post_hpo(executor,

     global g_quant_config
     g_quant_config = QuantConfig(
-        executor, place, model_dir, quantize_model_path, algo, hist_percent,
+        executor, place, model_dir, quantize_model_path, algo, hist_percent,
         bias_correct, batch_size, batch_num, train_sample_generator,
         eval_sample_generator, train_dataloader, eval_dataloader, eval_function,
-        model_filename, params_filename,
-        save_model_filename, save_params_filename, scope, quantizable_op_type,
-        is_full_quantize, weight_bits, activation_bits, weight_quantize_type,
-        optimize_model, is_use_cache_file, cache_dir)
+        model_filename, params_filename, save_model_filename,
+        save_params_filename, scope, quantizable_op_type, is_full_quantize,
+        weight_bits, activation_bits, weight_quantize_type, optimize_model,
+        is_use_cache_file, cache_dir)
     cs = ConfigurationSpace()

     hyper_params = []

     if 'hist' in algo:
         hist_percent = UniformFloatHyperparameter(
-            "hist_percent", hist_percent[0], hist_percent[1], default_value=hist_percent[0])
+            "hist_percent",
+            hist_percent[0],
+            hist_percent[1],
+            default_value=hist_percent[0])
         hyper_params.append(hist_percent)

     if len(algo) > 1:
-        algo = CategoricalHyperparameter(
-            "algo", algo, default_value=algo[0])
+        algo = CategoricalHyperparameter("algo", algo, default_value=algo[0])
         hyper_params.append(algo)
     else:
         algo = algo[0]
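
The hyperparameter classes used here come from the ConfigSpace package that SMAC3 builds on. As a standalone illustration of the same search-space pattern, a sketch using the default candidate lists from the quant_post_hpo signature above (not PaddleSlim code):

# Sketch of the search-space construction pattern, using the ConfigSpace
# package; values are the defaults from quant_post_hpo's signature.
from ConfigSpace import ConfigurationSpace
from ConfigSpace.hyperparameters import (
    CategoricalHyperparameter, UniformFloatHyperparameter,
    UniformIntegerHyperparameter)

cs = ConfigurationSpace()
cs.add_hyperparameters([
    UniformFloatHyperparameter(
        "hist_percent", 0.98, 0.999, default_value=0.98),
    CategoricalHyperparameter(
        "algo", ["KL", "hist", "avg", "mse"], default_value="KL"),
    UniformIntegerHyperparameter("batch_size", 10, 30, default_value=10),
    UniformIntegerHyperparameter("batch_num", 10, 30, default_value=10),
])
print(cs.sample_configuration())  # one candidate cfg, as passed to quantize()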
@@ -397,7 +408,10 @@ def quant_post_hpo(executor,
         weight_quantize_type = weight_quantize_type[0]
     if len(batch_size) > 1:
         batch_size = UniformIntegerHyperparameter(
-            "batch_size", batch_size[0], batch_size[1], default_value=batch_size[0])
+            "batch_size",
+            batch_size[0],
+            batch_size[1],
+            default_value=batch_size[0])
         hyper_params.append(batch_size)
     else:
         batch_size = batch_size[0]
@@ -407,7 +421,7 @@ def quant_post_hpo(executor,
             "batch_num", batch_num[0], batch_num[1], default_value=batch_num[0])
         hyper_params.append(batch_num)
     else:
-        batch_num = batch_num[0] 
+        batch_num = batch_num[0]

     if len(hyper_params) == 0:
         quant_post( \
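
One behavior worth noting from these hunks: each dimension becomes a tunable hyperparameter only when its candidate list has more than one entry (and hist_percent only when 'hist' is among the algos). With single-candidate lists throughout, hyper_params stays empty and the function falls through to a single plain quant_post run instead of launching the SMAC tuner. A hedged sketch of that degenerate call, reusing the placeholder names from the first sketch:

# Sketch: single-candidate lists everywhere, so no search space is built
# and quant_post_hpo performs exactly one post-training quantization run.
quant_post_hpo(
    executor=exe,                   # placeholders from the earlier sketch
    place=place,
    model_dir='./fp32_model',
    quantize_model_path='./quantized_model',
    train_sample_generator=gen,
    eval_sample_generator=gen,
    algo=["avg"],                   # one candidate -> no hyperparameter
    bias_correct=[False],
    batch_size=[32],
    batch_num=[10],
    weight_quantize_type=['channel_wise_abs_max'],
    runcount_limit=1)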
