Skip to content

Commit d98c773

Browse files
authored
[LatencyPredictor] add hardware (PaddlePaddle#1089)
* Add rk3288 predictor * fix some bugs for sparse conv2d.
1 parent d01db56 commit d98c773

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

paddleslim/analysis/extract_features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def get_features_from_paramkey(param_key, op_type, data_type):
6161
if quant_bits not in param_key:
6262
return None
6363

64-
weight = re.search(r'weight=(\(\d*, \d*, \d*, \d*\))',
64+
weight = re.search(r'weight=(\(\d*, -?\d*, \d*, \d*\))',
6565
param_key).group().split('=')[-1].strip(
6666
'('
6767
')').split(', ')

paddleslim/analysis/latency_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class TableLatencyPredictor(LatencyPredictor):
6868
Args:
6969
table_file(str): The path of file that records the device latency of operators.
7070
"""
71-
hardware_list = ['SD625', 'SD710']
71+
hardware_list = ['SD625', 'SD710', 'RK3288']
7272

7373
def __init__(self, table_file='SD710'):
7474
self.table_file = table_file

paddleslim/analysis/parse_ops.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ def get_key_from_op(op):
2424
if op_type == 'sparse_conv2d':
2525
out_shape = op.all_outputs()[0].shape()
2626
in_shape = op.inputs('Input')[0].shape()
27-
weight_shape = (out_shape[1], in_shape[1], 1, 1)
27+
if in_shape:
28+
weight_shape = (out_shape[1], in_shape[1], 1, 1)
29+
else:
30+
weight_shape = (out_shape[1], -1, 1, 1)
2831
NonZeroWeights = op.inputs('NonZeroWeights')[0].shape()[0]
2932

3033
stride = op.attr('strides')[1]
@@ -147,14 +150,13 @@ def get_key_from_op(op):
147150

148151
elif op_type == 'stack':
149152
data = op.all_inputs()
150-
X = "["
153+
X = ""
151154
for x in data:
152155
X += f"{x.shape()}"
153-
X += "]"
154156
axis = op.attr('axis')
155157
out_shape = op.all_outputs()[0].shape()
156158

157-
param_key = f'{op_type} X={X} axis={axis} out={out_shape}'
159+
param_key = f'{op_type} in={X} axis={axis} out={out_shape}'
158160

159161
elif op_type == 'exp':
160162
in_shape = op.all_inputs()[-1].shape()
@@ -219,7 +221,7 @@ def get_key_from_op(op):
219221
in_shape2 = op.all_inputs()[1].shape()
220222
out_shape = op.all_outputs()[0].shape()
221223

222-
param_key = f'{op_type} in={in_shape1} in={in_shape2} out={out_shape}'
224+
param_key = f'{op_type} X={in_shape1} Y={in_shape2} out={out_shape}'
223225

224226
elif op_type in ['calib', 'floor']:
225227
in_shape = op.all_inputs()[-1].shape()

0 commit comments

Comments
 (0)