
Commit 374a5f7

zzjjay and ceci3 authored
add an interface to predict the latency of compressed model (PaddlePaddle#1045)
* add an interface to predict the latency of compressed model
* just for rerun
* delete temporary model files

Co-authored-by: ceci3 <[email protected]>
1 parent 2ef31f9 commit 374a5f7

File tree

2 files changed: +121 -1 lines changed

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
import os
import shutil

import paddle

from paddleslim.analysis import TableLatencyPredictor
from .prune_model import get_sparse_model, get_prune_model
from .fake_ptq import post_quant_fake


def predict_compressed_model(model_file, param_file, hardware='SD710'):
    """
    Evaluate the latency of the model under various compression strategies.
    Args:
        model_file(str): Path to the inference model file to be compressed.
        param_file(str): Path to the inference params file to be compressed.
        hardware(str): Target device. Defaults to 'SD710'.
    Returns:
        latency_dict(dict): The latency of the model under various compression strategies.
    """
    latency_dict = {}

    model_filename = model_file.split('/')[-1]
    param_filename = param_file.split('/')[-1]

    # Baseline FP32 latency of the original model.
    predictor = TableLatencyPredictor(hardware)
    latency = predictor.predict(
        model_file=model_file, param_file=param_file, data_type='fp32')
    latency_dict.update({'origin_fp32': latency})

    # Fake-quantize the original model and predict its INT8 latency.
    paddle.enable_static()
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    post_quant_fake(
        exe,
        model_dir=os.path.dirname(model_file),
        model_filename=model_filename,
        params_filename=param_filename,
        save_model_path='quant_model',
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        is_full_quantize=False,
        activation_bits=8,
        weight_bits=8)
    quant_model_file = os.path.join('quant_model', model_filename)
    quant_param_file = os.path.join('quant_model', param_filename)

    latency = predictor.predict(
        model_file=quant_model_file,
        param_file=quant_param_file,
        data_type='int8')
    latency_dict.update({'origin_int8': latency})

    # Structured pruning at several ratios, measured in FP32 and INT8.
    for prune_ratio in [0.3, 0.4, 0.5, 0.6]:
        get_prune_model(
            model_file=model_file,
            param_file=param_file,
            ratio=prune_ratio,
            save_path='prune_model')
        prune_model_file = os.path.join('prune_model', model_filename)
        prune_param_file = os.path.join('prune_model', param_filename)

        latency = predictor.predict(
            model_file=prune_model_file,
            param_file=prune_param_file,
            data_type='fp32')
        latency_dict.update({f'prune_{prune_ratio}_fp32': latency})

        post_quant_fake(
            exe,
            model_dir='prune_model',
            model_filename=model_filename,
            params_filename=param_filename,
            save_model_path='quant_model',
            quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
            is_full_quantize=False,
            activation_bits=8,
            weight_bits=8)
        quant_model_file = os.path.join('quant_model', model_filename)
        quant_param_file = os.path.join('quant_model', param_filename)

        latency = predictor.predict(
            model_file=quant_model_file,
            param_file=quant_param_file,
            data_type='int8')
        latency_dict.update({f'prune_{prune_ratio}_int8': latency})

    # Unstructured sparsity at several ratios, measured in FP32 and INT8.
    for sparse_ratio in [0.70, 0.75, 0.80, 0.85, 0.90, 0.95]:
        get_sparse_model(
            model_file=model_file,
            param_file=param_file,
            ratio=sparse_ratio,
            save_path='sparse_model')
        sparse_model_file = os.path.join('sparse_model', model_filename)
        sparse_param_file = os.path.join('sparse_model', param_filename)

        latency = predictor.predict(
            model_file=sparse_model_file,
            param_file=sparse_param_file,
            data_type='fp32')
        latency_dict.update({f'sparse_{sparse_ratio}_fp32': latency})

        post_quant_fake(
            exe,
            model_dir='sparse_model',
            model_filename=model_filename,
            params_filename=param_filename,
            save_model_path='quant_model',
            quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
            is_full_quantize=False,
            activation_bits=8,
            weight_bits=8)
        quant_model_file = os.path.join('quant_model', model_filename)
        quant_param_file = os.path.join('quant_model', param_filename)

        latency = predictor.predict(
            model_file=quant_model_file,
            param_file=quant_param_file,
            data_type='int8')
        latency_dict.update({f'sparse_{sparse_ratio}_int8': latency})

    # Delete temporary model files
    shutil.rmtree('./quant_model')
    shutil.rmtree('./prune_model')
    shutil.rmtree('./sparse_model')
    return latency_dict
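A brief usage sketch for the new interface follows. It is a hypothetical example, not part of the commit: the model and parameter paths are placeholders, the import path depends on where this new module sits in the package (its file name is not shown in this excerpt), and the `hardware='SD710'` default and returned key names come from the code above.

# Hypothetical usage sketch (placeholder paths; import predict_compressed_model
# from wherever this new module lives under paddleslim.auto_compression.utils).
latency_dict = predict_compressed_model(
    model_file='./mobilenet_v1_infer/inference.pdmodel',
    param_file='./mobilenet_v1_infer/inference.pdiparams',
    hardware='SD710')

# Keys follow the patterns used above: 'origin_fp32', 'origin_int8',
# 'prune_<ratio>_fp32', 'prune_<ratio>_int8', 'sparse_<ratio>_fp32', 'sparse_<ratio>_int8'.
for strategy, ms in sorted(latency_dict.items()):
    print(f'{strategy}: {ms}')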

paddleslim/auto_compression/utils/prune_model.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ def get_sparse_model(model_file, param_file, ratio, save_path):
    folder = os.path.dirname(model_file)
    model_name = model_file.split('/')[-1]
-   model_name = model_file.split('/')[-1]
    if param_file is None:
        param_name = None
    else:
