Commit 1a9376a

zzjjay and minghaoBD authored
[LatencyPredictor] Add new op and speed up prediction (PaddlePaddle#1014)
* Add new op type and support fp16 model
* Preload predictors' model and speed up prediction
* Modify the save path of TMP files

Co-authored-by: minghaoBD <[email protected]>
1 parent 23cc74d commit 1a9376a

4 files changed: +59 -29 lines

paddleslim/analysis/_utils.py

Lines changed: 6 additions & 5 deletions

@@ -18,7 +18,7 @@
 import paddle
 import paddleslim
 import subprocess
-import sklearn
+import time
 __all__ = [
     "save_cls_model", "save_det_model", "save_seg_model", "nearest_interpolate",
     "opt_model", "load_predictor"
@@ -29,10 +29,11 @@ def opt_model(opt="paddle_lite_opt",
               model_file='',
               param_file='',
               optimize_out_type='protobuf',
-              valid_targets='arm'):
+              valid_targets='arm',
+              enable_fp16=False):
     assert os.path.exists(model_file) and os.path.exists(
         param_file), f'{model_file} or {param_file} does not exist.'
-    save_dir = f'./opt_models_tmp/{os.getpid()}'
+    save_dir = f'./opt_models_tmp/{os.getpid()}_{time.time()}'
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)

@@ -41,8 +42,8 @@ def opt_model(opt="paddle_lite_opt",
         model_out = os.path.join(save_dir, 'pbmodel')
     else:
         model_out = os.path.join(save_dir, 'model')
-
-    cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets}'
+    enable_fp16 = str(enable_fp16).lower()
+    cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16}'
     print(f'commands:{cmd}')
     m = subprocess.Popen(
         cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
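
With this change, opt_model can emit an fp16 Paddle-Lite model, and each run writes to a unique temporary directory so concurrent calls no longer collide on the save path. A minimal usage sketch, assuming a hypothetical exported inference model:

from paddleslim.analysis._utils import opt_model

# Hypothetical model paths; any *.pdmodel / *.pdiparams pair works.
pbmodel_file = opt_model(
    model_file='./inference_model/model.pdmodel',
    param_file='./inference_model/model.pdiparams',
    optimize_out_type='protobuf',
    enable_fp16=True)  # forwarded to paddle_lite_opt as --enable_fp16=true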

paddleslim/analysis/extract_features.py

Lines changed: 8 additions & 3 deletions

@@ -52,8 +52,13 @@ def get_features_from_paramkey(param_key, op_type, data_type):
     features = None

     if 'conv2d' in op_type:
-        flag_quant = 'quant=None' if data_type == 'fp32' else 'quant=True'
-        if flag_quant not in param_key:
+        if data_type == 'fp16':
+            quant_bits = 'bit_length=16'
+        elif data_type == 'int8':
+            quant_bits = 'bit_length=8'
+        else:
+            quant_bits = 'bit_length=None'
+        if quant_bits not in param_key:
             return None

         weight = re.search(r'weight=(\(\d*, \d*, \d*, \d*\))',
@@ -178,7 +183,7 @@ def get_features_from_paramkey(param_key, op_type, data_type):
           'leaky_relu' in op_type or 'tanh' in op_type or 'swish' in op_type or
           'softmax' in op_type or 'hard_sigmoid' in op_type or
           'sigmoid' in op_type or 'gelu' in op_type or 'clip' in op_type or
-          'shape' in op_type or 'interp_v2' in op_type):
+          'shape' in op_type or 'interp_v2' in op_type or 'sqrt' in op_type):

         inputs = re.search(r'in=(\((-?\d+,* *)+\))',
                            param_key).group().split('=')[-1].strip(
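
The conv2d filter now keys on bit_length rather than the quant flag: with the parse_ops.py change below, fp16 conv2d keys carry quant=True bit_length=16 while int8 keys carry quant=True bit_length=8, so the old quant=True test could not tell them apart. A small sketch of the new selection against a made-up param_key:

# Mirrors the data_type -> marker mapping introduced above.
def quant_marker(data_type):
    if data_type == 'fp16':
        return 'bit_length=16'
    elif data_type == 'int8':
        return 'bit_length=8'
    return 'bit_length=None'

# Made-up fp16 conv2d key: matched for fp16, rejected for fp32/int8.
sample_key = ('conv2d in=(1, 3, 224, 224) weight=(32, 3, 3, 3) '
              'out=(1, 32, 112, 112) pad=1 stride=2 group=1 dilation=1 '
              'quant=True bit_length=16')
assert quant_marker('fp16') in sample_key
assert quant_marker('int8') not in sample_key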

paddleslim/analysis/latency_predictor.py

Lines changed: 32 additions & 11 deletions

@@ -16,7 +16,7 @@

 import os
 import pickle
-import time
+import shutil
 import subprocess
 from .parse_ops import get_key_from_op
 from .extract_features import get_data_from_tables, get_features_from_paramkey
@@ -71,15 +71,16 @@ def __init__(self, table_file='SD710'):
         self.hardware = None
         self.threads = None
         self.predictor_state = False
+        self.predictor = {}
         self._initial_table()

     def _initial_table(self):
         if self.table_file in ['SD625', 'SD710', 'SD845', 'SD865']:
             self.hardware = self.table_file
-            if self.hardware in ['SD625', 'SD710']:
-                self.predictor_state = True
             self.threads = 4
             self.table_file = f'{self.hardware}_threads_4_power_mode_0.pkl'
+            if self.hardware in ['SD625', 'SD710']:
+                self.predictor_state = True
         if not os.path.exists(self.table_file):
             subprocess.call(
                 f'wget https://paddlemodels.bj.bcebos.com/PaddleSlim/analysis/{self.table_file}',
@@ -115,6 +116,19 @@ def _get_input_shape(self, graph):
                 break
         return in_shape

+    def _preload_predictor(self, data_type='fp32'):
+        op_types = [
+            'depthwise_conv2d', 'conv2d', 'pool2d', 'matmul', 'elementwise_add',
+            'elementwise_mul', 'concat', 'calib', 'swish'
+        ]
+        op_dir = self.table_file.split('.')[0] + '_batchsize_1'
+        for op_type in op_types:
+            model = load_predictor(op_type, op_dir, data_type)
+            key = op_type
+            if 'conv2d' in op_type:
+                key = f'{op_type}_{data_type}'
+            self.predictor[key] = model
+
     def predict(self,
                 model_file,
                 param_file,
@@ -125,22 +139,27 @@ def predict(self,

         Args:
             model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams).
-            data_type(str): Data type, fp32 or int8. Default : fp32
+            data_type(str): Data type, fp32, fp16 or int8.
             threads(int): threads num
             input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for input shape that has variable length.
         Returns:
             latency(float): The latency of the model.
         """
-        assert data_type in ['fp32', 'int8'
-                             ], f'data_type must be one of [fp32, int8]'
+        assert data_type in ['fp32', 'int8', 'fp16'
+                             ], f'data_type must be one of [fp32, int8, fp16]'

         if self.hardware and self.threads != threads:
             self._change_table(threads)

+        if self.predictor_state and f'conv2d_{data_type}' not in self.predictor:
+            self._preload_predictor(data_type)
+
+        enable_fp16 = True if data_type == 'fp16' else False
         pbmodel_file = opt_model(
             model_file=model_file,
             param_file=param_file,
-            optimize_out_type='protobuf', )
+            optimize_out_type='protobuf',
+            enable_fp16=enable_fp16)

         paddle.enable_static()
         with open(pbmodel_file, "rb") as f:
@@ -176,7 +195,7 @@ def predict(self,
             warnings.warn("OperatorType\tCalledTimes")
             for key in new_op:
                 warnings.warn(f"{key.ljust(15)}\t{new_op[key]}")
-
+        shutil.rmtree(os.path.dirname(pbmodel_file))
         return latency

     def op_predictor(self, op_type, param_key, data_type):
@@ -185,18 +204,20 @@ def op_predictor(self, op_type, param_key, data_type):
         Args:
             op_type: The operator's type
             param_key: The operator's parameter information.
-            data_type: Data type, fp32 or int8. Default : int8
+            data_type: Data type, fp32 or int8.
         Returns:
             latency(float): The latency of the operator.
         """

         latency = 0.0
-        op_dir = self.table_file.split('.')[0] + '_batchsize_1'
         if op_type in [
                 'depthwise_conv2d', 'conv2d', 'pool2d', 'matmul',
                 'elementwise_add', 'elementwise_mul', 'concat', 'calib', 'swish'
         ]:
-            predictor = load_predictor(op_type, op_dir, data_type)
+            key = op_type
+            if 'conv2d' in op_type:
+                key = f'{op_type}_{data_type}'
+            predictor = self.predictor[key]
             features = get_features_from_paramkey(param_key, op_type, data_type)
             latency = predictor.predict([features])
         else:
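
Taken together, the per-op predictor models are now loaded once per data type and reused across predict() calls, and the temporary optimized model directory is removed after each run. A minimal sketch of the new fp16 path, assuming hypothetical model paths:

from paddleslim.analysis import TableLatencyPredictor

predictor = TableLatencyPredictor(table_file='SD710')

# The first fp16 call triggers _preload_predictor('fp16'); subsequent
# fp16 calls reuse the cached per-op models instead of reloading them.
latency = predictor.predict(
    model_file='./inference_model/model.pdmodel',
    param_file='./inference_model/model.pdiparams',
    data_type='fp16')
print(f'predicted latency: {latency}')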

paddleslim/analysis/parse_ops.py

Lines changed: 13 additions & 10 deletions

@@ -24,25 +24,30 @@ def get_key_from_op(op):
     if 'conv2d' in op_type:
         out_shape = op.all_outputs()[0].shape()
         in_shape = op.all_inputs()[-1].shape()
+        in_name = op.all_inputs()[1].name()
         weight_shape = op.all_inputs()[-2].shape()
-        kernel = weight_shape[2]
+        weight_shape = (out_shape[1], weight_shape[1], weight_shape[2], weight_shape[3])
+
         stride = op.attr('strides')[1]
         padding = op.attr('paddings')[1]
         groups = op.attr('groups')
         dilation = op.attr('dilations')[1]
-        int8 = op.attr('enable_int8')
+        quant = op.attr('enable_int8')
         bit_length = op.attr('bit_length')
+        if op.attr(in_name + '_fp16') == 'fp16':
+            quant = True
+            bit_length = 16

-        param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} pad={padding} stride={stride} group={groups} dilation={dilation} quant={int8} bit_length={bit_length}'
+        param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} pad={padding} stride={stride} group={groups} dilation={dilation} quant={quant} bit_length={bit_length}'

     elif op_type == 'matmul' or op_type == 'matmul_v2':
         X = op.all_inputs()[0].shape()
         Y = op.all_inputs()[1].shape()
         out_shape = op.all_outputs()[0].shape()
-        int8 = op.attr('enable_int8')
+        quant = op.attr('enable_int8')
         bit_length = op.attr('bit_length')

-        param_key = f'{op_type} X={X} Y={Y} out={out_shape} quant={int8} bit_length={bit_length}'
+        param_key = f'{op_type} X={X} Y={Y} out={out_shape} quant={quant} bit_length={bit_length}'

     elif 'batch_norm' in op_type or 'layer_norm' in op_type:
         out_shape = op.all_outputs()[-1].shape()
@@ -67,14 +72,12 @@ def get_key_from_op(op):

     elif op_type in [
             'hard_swish', 'relu', 'leaky_relu', 'tanh', 'swish', 'softmax',
-            'hard_sigmoid', 'sigmoid', 'gelu', 'clip', 'shape'
+            'hard_sigmoid', 'sigmoid', 'gelu', 'clip', 'shape', 'sqrt'
     ] or 'transpose' in op_type or 'interp_v2' in op_type:
         in_shape = op.all_inputs()[-1].shape()
+        out_shape = op.all_outputs()[0].shape()

-        param_key = f'{op_type} in={in_shape}'
-        in_shape = op.all_inputs()[-1].shape()
-
-        param_key = f'{op_type} in={in_shape}'
+        param_key = f'{op_type} in={in_shape} out={out_shape}'

     elif op_type in ['fill_constant', 'range', 'cast'] or 'expand' in op_type:
