
Commit 70a3642

add ptq data-free method (PaddlePaddle#1026)
* add ptq data-free method
1 parent: d31a202

File tree

1 file changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass, AddQuantDequantPass, QuantizationFreezePass


def post_quant_fake(executor,
                    model_dir,
                    model_filename=None,
                    params_filename=None,
                    save_model_path=None,
                    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                    is_full_quantize=False,
                    activation_bits=8,
                    weight_bits=8):
"""
17+
Utilizing post training quantization methon to quantize the FP32 model,
18+
and it not uses calibrate data and the fake model cannot be used in practice.
19+
Usage:
20+
paddle.enable_static()
21+
place = paddle.CPUPlace()
22+
exe = paddle.static.Executor(place)
23+
post_quant_fake(executor=exe,
24+
model_dir='./inference_model/MobileNet/',
25+
model_filename='model',
26+
params_filename='params',
27+
save_model_path='fake_quant')
28+
"""
29+
    activation_quantize_type = 'range_abs_max'
    weight_quantize_type = 'channel_wise_abs_max'
    _dynamic_quantize_op_type = ['lstm']
    _weight_supported_quantizable_op_type = QuantizationTransformPass._supported_quantizable_op_type
    _act_supported_quantizable_op_type = AddQuantDequantPass._supported_quantizable_op_type
    _support_quantize_op_type = list(
        set(_weight_supported_quantizable_op_type +
            _act_supported_quantizable_op_type + _dynamic_quantize_op_type))
    _place = executor.place
    _scope = paddle.static.global_scope()
    if is_full_quantize:
        _quantizable_op_type = _support_quantize_op_type
    else:
        _quantizable_op_type = quantizable_op_type
        for op_type in _quantizable_op_type:
            assert op_type in _support_quantize_op_type, \
                op_type + " is not supported for quantization."

    _program, _feed_list, _fetch_list = paddle.fluid.io.load_inference_model(
        model_dir,
        executor,
        model_filename=model_filename,
        params_filename=params_filename)

    graph = IrGraph(core.Graph(_program.desc), for_test=True)

    # use QuantizationTransformPass to insert fake_quant/fake_dequantize ops
    major_quantizable_op_types = []
    for op_type in _weight_supported_quantizable_op_type:
        if op_type in _quantizable_op_type:
            major_quantizable_op_types.append(op_type)
    transform_pass = QuantizationTransformPass(
        scope=_scope,
        place=_place,
        weight_bits=weight_bits,
        activation_bits=activation_bits,
        activation_quantize_type=activation_quantize_type,
        weight_quantize_type=weight_quantize_type,
        quantizable_op_type=major_quantizable_op_types)

    for sub_graph in graph.all_sub_graphs():
        # Inserting fake_quant/fake_dequantize ops must be done on a test
        # graph, so set _for_test to True for every sub graph.
        sub_graph._for_test = True
        transform_pass.apply(sub_graph)

    # use AddQuantDequantPass to insert fake_quant_dequant op
    minor_quantizable_op_types = []
    for op_type in _act_supported_quantizable_op_type:
        if op_type in _quantizable_op_type:
            minor_quantizable_op_types.append(op_type)
    add_quant_dequant_pass = AddQuantDequantPass(
        scope=_scope,
        place=_place,
        quantizable_op_type=minor_quantizable_op_types)

    for sub_graph in graph.all_sub_graphs():
        sub_graph._for_test = True
        add_quant_dequant_pass.apply(sub_graph)

    # apply QuantizationFreezePass, and obtain the final quant model
    freeze_pass = QuantizationFreezePass(
        scope=_scope,
        place=_place,
        weight_bits=weight_bits,
        activation_bits=activation_bits,
        weight_quantize_type=weight_quantize_type,
        quantizable_op_type=major_quantizable_op_types)

    for sub_graph in graph.all_sub_graphs():
        sub_graph._for_test = True
        freeze_pass.apply(sub_graph)

    _program = graph.to_program()

    paddle.fluid.io.save_inference_model(
        dirname=save_model_path,
        model_filename=model_filename,
        params_filename=params_filename,
        feeded_var_names=_feed_list,
        target_vars=_fetch_list,
        executor=executor,
        main_program=_program)
    print("The quantized model is saved in: " + save_model_path)
