Commit 380bce6

add ppdet auto compression demo (PaddlePaddle#1039)
1 parent 3cf6116 commit 380bce6

File tree

8 files changed: +269 −60 lines
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
# Example: Quantization-Aware Training from an Inference Model

APIs for saving an inference model:
a dynamic-graph model is saved with ``paddle.jit.save``;
a static-graph model is saved with ``paddle.static.save_inference_model``. A minimal example of the static-graph case is sketched below.

This example shows how to run distillation-based quantization-aware training on an inference model exported from PaddleDetection.

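For reference, here is a minimal, self-contained sketch of saving a static-graph inference model; the toy network and output path are illustrative assumptions, and PaddleDetection's ``tools/export_model.py`` (used in step 2.3 below) does the equivalent for detection models.

```python
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# Build a toy static-graph program with a single conv layer.
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    image = paddle.static.data(name='image', shape=[None, 3, 224, 224], dtype='float32')
    out = paddle.static.nn.conv2d(image, num_filters=8, filter_size=3)
exe.run(startup_prog)

# Writes ./toy_infer_model/model.pdmodel and ./toy_infer_model/model.pdiparams.
paddle.static.save_inference_model('./toy_infer_model/model', [image], [out], exe)
```
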
## Quantization-Distillation Training Workflow

### 1. Prepare data in COCO format

See the [COCO data preparation guide](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md#coco%E6%95%B0%E6%8D%AE). With the default ``coco_dataset.yml`` shipped in this commit, the data loader expects the layout sketched below.

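The expected directory layout, derived from the paths in ``coco_dataset.yml`` (``dataset_dir``, ``image_dir``, ``anno_path``); resolving ``dataset/coco/`` against the working directory is an assumption here:

```
dataset/coco/
├── annotations/
│   ├── instances_train2017.json
│   └── instances_val2017.json
├── train2017/
└── val2017/
```
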

### 2. Prepare the environment for quantization

- PaddlePaddle >= 2.2
- PaddleDet >= 2.3

```shell
pip install paddledet
```
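
Optionally, verify the installation from Python before moving on; this quick check is an editorial addition, not part of the original demo:

```python
import paddle

print(paddle.__version__)   # expect 2.2 or newer
paddle.utils.run_check()    # confirms the local CPU/GPU setup works
```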

#### 2.3 Prepare the model to be quantized
- Download the code
```
git clone https://github.com/PaddlePaddle/PaddleDetection.git
```
- Export the inference model
```shell
python tools/export_model.py -c configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v1_270e_coco.pdparams
```
Or download it directly:
```shell
wget https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v1_270e_coco.tar
tar -xf yolov3_mobilenet_v1_270e_coco.tar
```

#### 2.4 Test the model accuracy
Copy the ``yolov3_mobilenet_v1_270e_coco`` folder into the ``PaddleSlim/demo/auto-compression/`` directory.
```
cd PaddleSlim/demo/auto-compression/
```
Use the [demo_coco.py](../demo_coco.py) script to measure the model's detection accuracy (COCO mAP):
```
python3.7 ../demo_coco.py --model_dir=../yolov3_mobilenet_v1_270e_coco/ --model_filename=model.pdmodel --params_filename=model.pdiparams --eval=True
```

### 3. Apply multi-strategy compression

Each of the following subsections describes one multi-strategy compression recipe; they are alternatives and do not need to be run one after another.

#### 3.1 Quantization-aware training with distillation
The example script for quantization-aware training with distillation is [demo_coco.py](../demo_coco.py); it uses the ``paddleslim.auto_compression.AutoCompression`` API to quantize the model (a condensed sketch of the flow follows the command below). Run it with:
```
python ../demo_coco.py \
    --model_dir='./yolov3_mobilenet_v1_270e_coco/' \
    --model_filename='model.pdmodel' \
    --params_filename='model.pdiparams' \
    --save_dir='./output/' \
    --devices='gpu' \
    --config_path='./yolov3_mbv1_qat_dis.yaml'
```
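
For orientation, here is a condensed, self-contained sketch of the flow that demo_coco.py wires together (the full script is part of this commit). Passing ``eval_callback=None`` is an editorial simplification; the real script supplies a COCO mAP evaluation function.

```python
import paddle
from ppdet.core.workspace import load_config, create
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression

paddle.enable_static()

# Build a PaddleDetection dataloader from the reader config shipped with this demo.
reader_cfg = load_config('./configs/PaddleDet/yolo_reader.yml')
train_loader = create('TestReader')(reader_cfg['TrainDataset'],
                                    reader_cfg['worker_num'],
                                    return_list=True)

def reader_wrapper(reader):
    # AutoCompression consumes a generator that yields feed dicts keyed by input names.
    def gen():
        for data in reader:
            yield {'image': data['image'],
                   'im_shape': data['im_shape'],
                   'scale_factor': data['scale_factor']}
    return gen

# Split the strategy yaml into compression strategies and the training schedule.
compress_config, train_config = load_slim_config('./yolov3_mbv1_qat_dis.yaml')

ac = AutoCompression(
    model_dir='./yolov3_mobilenet_v1_270e_coco/',
    model_filename='model.pdmodel',
    params_filename='model.pdiparams',
    save_dir='./output/',
    strategy_config=compress_config,
    train_config=train_config,
    train_dataloader=reader_wrapper(train_loader),
    eval_callback=None,  # demo_coco.py passes its COCO mAP eval_function here
    devices='gpu')
ac.compress()
```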

demo/auto-compression/configs/PaddleDet/coco_dataset.yml

Lines changed: 2 additions & 3 deletions
@@ -5,14 +5,13 @@ TrainDataset:
   !COCODataSet
     image_dir: train2017
     anno_path: annotations/instances_train2017.json
-    dataset_dir: dataset/coco
-    data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+    dataset_dir: dataset/coco/
 
 EvalDataset:
   !COCODataSet
     image_dir: val2017
     anno_path: annotations/instances_val2017.json
-    dataset_dir: dataset/coco
+    dataset_dir: dataset/coco/
 
 TestDataset:
   !ImageFolder

demo/auto-compression/configs/PaddleDet/ppyoloe_reader.yml

Lines changed: 0 additions & 40 deletions
This file was deleted.
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
_BASE_: [
  './coco_dataset.yml',
]

worker_num: 8

TestReader:
  inputs_def:
    image_shape: [3, 640, 640]
  sample_transforms:
    - Decode: {}
    - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
    - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
    - Permute: {}
  batch_size: 4
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
Distillation:
  distill_lambda: 1.0
  distill_loss: l2_loss
  distill_node_pair:
  - teacher_conv2d_84.tmp_0
  - conv2d_84.tmp_0
  - teacher_conv2d_85.tmp_0
  - conv2d_85.tmp_0
  - teacher_conv2d_86.tmp_0
  - conv2d_86.tmp_0
  merge_feed: true
  teacher_model_dir: ./yolov3_mobilenet_v1_270e_coco/
  teacher_model_filename: model.pdmodel
  teacher_params_filename: model.pdiparams
Quantization:
  activation_bits: 8
  is_full_quantize: false
  not_quant_pattern:
  - skip_quant
  quantize_op_types:
  - conv2d
  - depthwise_conv2d
  weight_bits: 8
TrainConfig:
  epochs: 1
  eval_iter: 1000
  learning_rate: 0.0001
  optimizer: SGD
  optim_args:
    weight_decay: 4.0e-05
  #origin_metric: 0.289

demo/auto-compression/demo_coco.py

Lines changed: 147 additions & 6 deletions
@@ -2,15 +2,156 @@
 import sys
 sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
 import argparse
-
+import functools
+from functools import partial
+import numpy as np
+import paddle
 from ppdet.core.workspace import load_config, merge_config
 from ppdet.core.workspace import create
+from ppdet.metrics import COCOMetric
+from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
+from paddleslim.auto_compression import AutoCompression
+
+paddle.enable_static()
+
+from utility import add_arguments, print_arguments
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+
+# yapf: disable
+add_arg('model_dir', str, None, "inference model directory.")
+add_arg('model_filename', str, None, "inference model filename.")
+add_arg('params_filename', str, None, "inference params filename.")
+add_arg('save_dir', str, 'output', "directory to save compressed model.")
+add_arg('devices', str, 'gpu', "which device used to compress.")
+add_arg('batch_size', int, 1, "train batch size.")
+add_arg('config_path', str, None, "path of compression strategy config.")
+add_arg('eval', bool, False, "whether to run evaluation.")
+# yapf: enable
+
+
+def reader_wrapper(reader):
+    def gen():
+        for data in reader:
+            yield {
+                "image": data['image'],
+                'im_shape': data['im_shape'],
+                'scale_factor': data['scale_factor']
+            }
+
+    return gen
+
+
+def eval():
+    dataset = reader_cfg['EvalDataset']
+    val_loader = create('TestReader')(dataset,
+                                      reader_cfg['worker_num'],
+                                      return_list=True)
+
+    place = paddle.CUDAPlace(0) if args.devices == 'gpu' else paddle.CPUPlace()
+    exe = paddle.static.Executor(place)
+
+    val_program, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
+        args.model_dir,
+        exe,
+        model_filename=args.model_filename,
+        params_filename=args.params_filename)
+    clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
+
+    anno_file = dataset.get_anno()
+    metric = COCOMetric(
+        anno_file=anno_file, clsid2catid=clsid2catid, bias=0, IouType='bbox')
+    for batch_id, data in enumerate(val_loader):
+        data_new = {k: np.array(v) for k, v in data.items()}
+        outs = exe.run(val_program,
+                       feed={
+                           'image': data['image'],
+                           'im_shape': data['im_shape'],
+                           'scale_factor': data['scale_factor']
+                       },
+                       fetch_list=fetch_targets,
+                       return_numpy=False)
+        res = {}
+        for out in outs:
+            v = np.array(out)
+            if len(v.shape) > 1:
+                res['bbox'] = v
+            else:
+                res['bbox_num'] = v
+
+        metric.update(data_new, res)
+        if batch_id % 100 == 0:
+            print('Eval iter:', batch_id)
+    metric.accumulate()
+    metric.log()
+    metric.reset()
+
+
+def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
+    clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
+
+    anno_file = dataset.get_anno()
+    metric = COCOMetric(
+        anno_file=anno_file, clsid2catid=clsid2catid, bias=1, IouType='bbox')
+    for batch_id, data in enumerate(val_loader):
+        data_new = {k: np.array(v) for k, v in data.items()}
+        outs = exe.run(compiled_test_program,
+                       feed={
+                           'image': data['image'],
+                           'im_shape': data['im_shape'],
+                           'scale_factor': data['scale_factor']
+                       },
+                       fetch_list=test_fetch_list,
+                       return_numpy=False)
+        res = {}
+        for out in outs:
+            v = np.array(out)
+            if len(v.shape) > 1:
+                res['bbox'] = v
+            else:
+                res['bbox_num'] = v
+
+        metric.update(data_new, res)
+        if batch_id % 100 == 0:
+            print('Eval iter:', batch_id)
+    metric.accumulate()
+    metric.log()
+    map_res = metric.get_results()
+    metric.reset()
+    return map_res['bbox'][0]
+
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    print_arguments(args)
+    paddle.enable_static()
+    reader_cfg = load_config('./configs/PaddleDet/yolo_reader.yml')
+    if args.eval:
+        eval()
+        sys.exit(0)
+
+    compress_config, train_config = load_slim_config(args.config_path)
 
-cfg = load_config('./configs/PaddleDet/ppyoloe_reader.yml')
+    train_loader = create('TestReader')(reader_cfg['TrainDataset'],
+                                        reader_cfg['worker_num'],
+                                        return_list=True)
+    dataset = reader_cfg['EvalDataset']
+    val_loader = create('TestReader')(reader_cfg['EvalDataset'],
+                                      reader_cfg['worker_num'],
+                                      return_list=True)
 
-print(cfg)
+    train_dataloader = reader_wrapper(train_loader)
 
-coco_loader = create('TestReader')(cfg['TrainDataset'], cfg['worker_num'])
+    ac = AutoCompression(
+        model_dir=args.model_dir,
+        model_filename=args.model_filename,
+        params_filename=args.params_filename,
+        save_dir=args.save_dir,
+        strategy_config=compress_config,
+        train_config=train_config,
+        train_dataloader=train_dataloader,
+        eval_callback=eval_function,
+        devices=args.devices)
 
-for data in coco_loader:
-    print(data.keys())
+    ac.compress()

paddleslim/auto_compression/compressor.py

Lines changed: 2 additions & 1 deletion
@@ -315,7 +315,8 @@ def _start_train(self, train_program_info, test_program_info):
                     _logger.info("epoch: {}, batch: {}, loss: {}".format(
                         epoch_id, batch_id, np_probs_float))
 
-                if batch_id % int(self.train_config.eval_iter) == 0:
+                if batch_id % int(
+                        self.train_config.eval_iter) == 0 and batch_id != 0:
                     if self.eval_function is not None:
 
                         # GMP pruner step 3: update params before summarizing sparsity, saving model or evaluation.

paddleslim/quant/quanter.py

Lines changed: 8 additions & 10 deletions
@@ -198,7 +198,8 @@ def quant_aware(program,
                 optimizer_func=None,
                 executor=None,
                 onnx_format=False,
-                return_program=False):
+                return_program=False,
+                draw_graph=False):
     """Add quantization and dequantization operators to "program"
     for quantization training or testing.
 
@@ -241,6 +242,8 @@ def quant_aware(program,
             initialization. Default is None.
         return_program(bool): If the return value should be a Program rather than a CompiledProgram, set this to True.
             Default is False.
+        draw_graph(bool): Whether to draw the graph when quantization is initialized. In order to prevent a cycle,
+            this needs to be set to True for the ERNIE model. Default is False.
     Returns:
         paddle.static.CompiledProgram | paddle.static.Program: Program with quantization and dequantization ``operators``
     """
@@ -308,15 +311,10 @@ def quant_aware(program,
                     VARS_MAPPING_TABLE))
             save_dict(main_graph.out_node_mapping_table)
 
-    main_graph.draw('./', 'graph.pdf')
-    #remove_ctr_vars = set()
-    #from paddle.fluid.framework import IrVarNode
-    #all_var_nodes = {IrVarNode(node) for node in main_graph.nodes() if node.is_var()}
-    #for node in all_var_nodes:
-    #    print("node: ", node)
-    #    if node.is_ctrl_var():
-    #        remove_ctr_vars.add(node)
-    #self.safe_remove_nodes(remove_ctr_vars)
+    # TODO: remove it.
+    if draw_graph:
+        main_graph.draw('./', 'graph.pdf')
+
     if for_test or return_program:
         quant_program = main_graph.to_program()
     else:
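
For context, a self-contained sketch of calling ``paddleslim.quant.quant_aware`` with the new ``draw_graph`` flag on a toy program; this is an editorial illustration, not code from this commit.

```python
import paddle
import paddleslim

paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)

# A toy conv + fc program to quantize.
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    image = paddle.static.data(name='image', shape=[None, 3, 32, 32], dtype='float32')
    conv = paddle.static.nn.conv2d(image, num_filters=8, filter_size=3)
    out = paddle.static.nn.fc(paddle.flatten(conv, 1, -1), size=10)
exe.run(startup_prog)

quant_config = {
    'weight_bits': 8,
    'activation_bits': 8,
    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
}
# With draw_graph=True the quantized graph is dumped to ./graph.pdf
# (per the docstring above, ERNIE models should set this to True).
quant_prog = paddleslim.quant.quant_aware(
    main_prog, place, config=quant_config, for_test=True,
    return_program=True, draw_graph=False)
```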
