Skip to content

Commit 46751fe

Browse files
authored
Method of adding hardware (PaddlePaddle#1041)
* fix opt's cmd for sparse model
* add hardware
* Remove redundant functions and adjust tests' file
* Remove redundant functions and adjust tests' file
1 parent cb64200 commit 46751fe

File tree

4 files changed

+21
-94
lines changed

4 files changed

+21
-94
lines changed

paddleslim/analysis/__init__.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,10 @@
1616
from .latency import LatencyEvaluator, TableLatencyEvaluator
1717
from .latency_predictor import LatencyPredictor, TableLatencyPredictor
1818
from .parse_ops import get_key_from_op
19-
from ._utils import save_cls_model, save_det_model, save_seg_model
19+
from ._utils import save_cls_model, save_det_model
2020

2121
__all__ = [
22-
'flops',
23-
'dygraph_flops',
24-
'model_size',
25-
'LatencyEvaluator',
26-
'TableLatencyEvaluator',
27-
"LatencyPredictor",
28-
"TableLatencyPredictor",
29-
"get_key_from_op",
30-
"save_cls_model",
31-
"save_det_model",
32-
"save_seg_model",
22+
'flops', 'dygraph_flops', 'model_size', 'LatencyEvaluator',
23+
'TableLatencyEvaluator', "LatencyPredictor", "TableLatencyPredictor",
24+
"get_key_from_op", "save_cls_model", "save_det_model"
3325
]

paddleslim/analysis/_utils.py

Lines changed: 5 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
import subprocess
2121
import time
2222
__all__ = [
23-
"save_cls_model", "save_det_model", "save_seg_model", "nearest_interpolate",
24-
"opt_model", "load_predictor"
23+
"save_cls_model", "save_det_model", "nearest_interpolate", "opt_model",
24+
"load_predictor"
2525
]
2626

2727

@@ -30,8 +30,7 @@ def opt_model(opt="paddle_lite_opt",
3030
param_file='',
3131
optimize_out_type='protobuf',
3232
valid_targets='arm',
33-
enable_fp16=False,
34-
sparse_ratio=0):
33+
enable_fp16=False):
3534
assert os.path.exists(model_file) and os.path.exists(
3635
param_file), f'{model_file} or {param_file} does not exist.'
3736
save_dir = f'./opt_models_tmp/{os.getpid()}_{time.time()}'
@@ -40,15 +39,13 @@ def opt_model(opt="paddle_lite_opt",
4039

4140
assert optimize_out_type in ['protobuf', 'naive_buffer']
4241
if optimize_out_type == 'protobuf':
43-
model_out = os.path.join(save_dir, 'pbmodel')
42+
model_out = save_dir
4443
else:
4544
model_out = os.path.join(save_dir, 'model')
4645

4746
enable_fp16 = str(enable_fp16).lower()
48-
sparse_model = True if sparse_ratio > 0 else False
49-
sparse_threshold = max(sparse_ratio - 0.1, 0.1)
5047

51-
cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16} --sparse_model={sparse_model} --sparse_threshold={sparse_threshold}'
48+
cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets} --enable_fp16={enable_fp16} --sparse_model=true --sparse_threshold=0.4'
5249
print(f'commands:{cmd}')
5350
m = subprocess.Popen(
5451
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
@@ -169,48 +166,6 @@ def save_det_model(model,
169166
return model_file, param_file
170167

171168

172-
def save_seg_model(model, input_shape, save_dir, data_type):
173-
if data_type == 'fp32':
174-
paddle.jit.save(
175-
model,
176-
path=os.path.join(save_dir, 'fp32model'),
177-
input_spec=[
178-
paddle.static.InputSpec(
179-
shape=input_shape, dtype='float32', name='x'),
180-
])
181-
model_file = os.path.join(save_dir, 'fp32model.pdmodel')
182-
param_file = os.path.join(save_dir, 'fp32model.pdiparams')
183-
184-
else:
185-
save_dir = os.path.join(save_dir, 'int8model')
186-
quant_config = {
187-
'weight_preprocess_type': None,
188-
'activation_preprocess_type': None,
189-
'weight_quantize_type': 'channel_wise_abs_max',
190-
'activation_quantize_type': 'moving_average_abs_max',
191-
'weight_bits': 8,
192-
'activation_bits': 8,
193-
'dtype': 'int8',
194-
'window_size': 10000,
195-
'moving_rate': 0.9,
196-
'quantizable_layer_type': ['Conv2D', 'Linear'],
197-
}
198-
quantizer = paddleslim.QAT(config=quant_config)
199-
quantizer.quantize(model)
200-
quantizer.save_quantized_model(
201-
model,
202-
save_dir,
203-
input_spec=[
204-
paddle.static.InputSpec(
205-
shape=input_shape, dtype='float32')
206-
])
207-
208-
model_file = f'{save_dir}.pdmodel'
209-
param_file = f'{save_dir}.pdiparams'
210-
211-
return model_file, param_file
212-
213-
214169
def nearest_interpolate(features, data):
215170
def distance(x, y):
216171
x = np.array(x)

paddleslim/analysis/latency_predictor.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class LatencyPredictor(object):
3939
"""Base class of latency predictor.
4040
"""
4141

42-
def predict_latency(self, model):
42+
def predict(self, model):
4343
"""Get latency of model. It is an abstract method.
4444
4545
Args:
@@ -64,6 +64,7 @@ class TableLatencyPredictor(LatencyPredictor):
6464
Args:
6565
table_file(str): The path of file that records the device latency of operators.
6666
"""
67+
hardware_list = ['SD625', 'SD710']
6768

6869
def __init__(self, table_file='SD710'):
6970
self.table_file = table_file
@@ -72,11 +73,14 @@ def __init__(self, table_file='SD710'):
7273
self.threads = None
7374
self.predictor_state = False
7475
self.predictor = {}
75-
self.hardware_list = ['SD625', 'SD710']
7676
self._initial_table()
7777

78+
@classmethod
79+
def add_hardware(cls, hardware):
80+
cls.hardware_list.append(hardware)
81+
7882
def _initial_table(self):
79-
if self.table_file in self.hardware_list:
83+
if self.table_file in TableLatencyPredictor.hardware_list:
8084
self.hardware = self.table_file
8185
self.threads = 4
8286
self.table_file = f'{self.hardware}_threads_4_power_mode_0.pkl'
@@ -88,7 +92,7 @@ def _initial_table(self):
8892

8993
assert os.path.exists(
9094
self.table_file
91-
), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in {self.hardware_list}'
95+
), f'{self.table_file} does not exist. If you want to use our table files, please set \'table_file\' in {TableLatencyPredictor.hardware_list}'
9296
with open(self.table_file, 'rb') as f:
9397
self.table_dict = pickle.load(f)
9498

@@ -123,6 +127,8 @@ def _preload_predictor(self, data_type='fp32'):
123127
]
124128
op_dir = self.table_file.split('.')[0] + '_batchsize_1'
125129
for op_type in op_types:
130+
if data_type == 'fp32' and op_type == 'calib':
131+
continue
126132
model = load_predictor(op_type, op_dir, data_type)
127133
key = op_type
128134
if 'conv2d' in op_type:
@@ -141,8 +147,6 @@ def predict(self,
141147
model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams).
142148
data_type(str): Data type, fp32, fp16 or int8.
143149
threads(int): Threads num.
144-
sparse_ratio(float): The ratio of unstructured pruning.
145-
prune_ratio(float): The ration of structured pruning.
146150
input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for input shape that has variable length.
147151
Returns:
148152
latency(float): The latency of the model.

tests/test_latency_predictor.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from paddleslim.analysis import LatencyPredictor, TableLatencyPredictor
2020
from paddle.vision.models import mobilenet_v1, mobilenet_v2
2121
from paddle.nn import Conv2D, BatchNorm2D, ReLU, LayerNorm
22-
from paddleslim.analysis._utils import opt_model, save_cls_model, save_seg_model, save_det_model
22+
from paddleslim.analysis._utils import opt_model, save_cls_model, save_det_model
2323

2424

2525
def channel_shuffle(x, groups):
@@ -276,7 +276,7 @@ def test_case5(self):
276276
paddle.disable_static()
277277
model = mobilenet_v1()
278278
predictor = TableLatencyPredictor(table_file='SD710')
279-
model_file, param_file = save_seg_model(
279+
model_file, param_file = save_cls_model(
280280
model,
281281
input_shape=[1, 3, 224, 224],
282282
save_dir="./inference_model",
@@ -367,30 +367,6 @@ def test_case9(self):
367367

368368
class TestCase10(unittest.TestCase):
369369
def test_case10(self):
370-
paddle.disable_static()
371-
model = ModelCase1()
372-
predictor = LatencyPredictor()
373-
model_file, param_file = save_seg_model(
374-
model,
375-
input_shape=[1, 116, 28, 28],
376-
save_dir="./inference_model",
377-
data_type='int8')
378-
pbmodel_file = opt_model(
379-
model_file=model_file,
380-
param_file=param_file,
381-
optimize_out_type='protobuf')
382-
383-
paddle.enable_static()
384-
with open(pbmodel_file, "rb") as f:
385-
fluid_program = paddle.fluid.framework.Program.parse_from_string(
386-
f.read())
387-
graph = paddleslim.core.GraphWrapper(fluid_program)
388-
graph_keys = predictor._get_key_info_from_graph(graph=graph)
389-
assert len(graph_keys) > 0
390-
391-
392-
class TestCase11(unittest.TestCase):
393-
def test_case11(self):
394370
paddle.disable_static()
395371
model = mobilenet_v2()
396372
model2 = ModelCase6()

0 commit comments

Comments
 (0)