
Commit f01fc4f

fix demo for now (PaddlePaddle#1057)
1 parent 374a5f7 commit f01fc4f

File tree

3 files changed: +4 -3 lines changed


demo/auto-compression/README.md

Lines changed: 2 additions & 2 deletions
@@ -54,7 +54,7 @@ python tools/export_model.py \
 ```
 cd PaddleSlim/demo/auto-compression/
 ```
-Use the [eval.py](../quant/quant_post/eval.py) script to obtain the model's classification accuracy:
+Use the [eval.py](../quant/quant_post/eval.py) script to obtain the model's classification accuracy; the compressed model can also be evaluated with the same script.
 ```
 python ../quant/quant_post/eval.py --model_path infermodel_mobilenetv2 --model_name inference.pdmodel --params_name inference.pdiparams
 ```
@@ -95,7 +95,7 @@ python demo_imagenet.py \
 
 ### 3.3 Combined pruning and distillation compression
 Note: this example applies ASP sparsity to a BERT model.
-First, follow the [script](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/language_model/bert#%E9%A2%84%E6%B5%8B) to obtain a deployable model.
+First, follow the [script](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/language_model/bert#%E9%A2%84%E6%B5%8B) to obtain a deployable model, or download the example model for the SST-2 dataset: [SST-2-BERT](https://paddle-qa.bj.bcebos.com/PaddleSlim_datasets/static_bert_models.tar.gz).
 The pruning and distillation compression example script is [demo_glue.py](./demo_glue.py); it uses the ``paddleslim.auto_compression.AutoCompression`` interface to compress the model. Run it with:
 ```
 python demo_glue.py \
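As a convenience for the step above, here is a minimal sketch (not part of this commit; the archive name and extraction path are illustrative assumptions, only the URL comes from the README change) of fetching and unpacking the linked SST-2-BERT example model with the Python standard library:

```python
# Hedged helper sketch: download and extract the SST-2-BERT example model
# referenced in the README change above. Only the URL comes from the commit;
# the local archive name and extraction path below are illustrative.
import tarfile
import urllib.request

URL = "https://paddle-qa.bj.bcebos.com/PaddleSlim_datasets/static_bert_models.tar.gz"
ARCHIVE = "static_bert_models.tar.gz"

urllib.request.urlretrieve(URL, ARCHIVE)      # download the tarball
with tarfile.open(ARCHIVE, "r:gz") as tar:
    tar.extractall(path=".")                  # unpack the deployable BERT model files
```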

demo/auto-compression/demo_imagenet.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
     strategy_config=compress_config,
     train_config=train_config,
     train_dataloader=train_dataloader,
-    eval_callback=eval_function,
+    eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else None,
     devices=args.devices)
 
 ac.compress()
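For readers skimming the diff, below is a minimal standalone sketch of the selection logic introduced here; `dummy_eval_function` and the example `compress_config` dict are placeholders for the objects the demo builds earlier, not the demo's real values. The point of the change is that the user-supplied eval callback is disabled whenever the strategy config contains a 'HyperParameterOptimization' entry.

```python
# Standalone sketch of the eval_callback selection added in this commit.
# `dummy_eval_function` and `compress_config` stand in for the objects that
# demo_imagenet.py builds earlier in the script.

def dummy_eval_function():
    """Placeholder for the demo's eval_function (would return an accuracy)."""
    return 0.0

compress_config = {'QuantPost': {}, 'HyperParameterOptimization': {}}

# Same conditional as the changed line: skip the user callback when the
# hyper-parameter-optimization strategy is part of the config.
eval_callback = (dummy_eval_function
                 if 'HyperParameterOptimization' not in compress_config
                 else None)

print(eval_callback)  # None, because 'HyperParameterOptimization' is present
```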

paddleslim/quant/quant_post_hpo.py

Lines changed: 1 addition & 0 deletions
@@ -273,6 +273,7 @@ def quantize(cfg):
 
     global g_min_emd_loss
     ### if eval_function is not None, use eval function provided by user.
+    ### TODO(ceci3): fix eval_function
     if g_quant_config.eval_function is not None:
         emd_loss = g_quant_config.eval_function()
     else:
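For context, the branch that the new TODO annotates follows the pattern sketched below. Apart from the `eval_function is not None` check itself, the config class and the fallback helper are hypothetical stand-ins, since the real else-branch is not shown in this hunk.

```python
# Hypothetical, self-contained sketch of the branch above the new TODO.
# Only the `eval_function is not None` check mirrors the real file; the
# config class and fallback helper here are placeholders.

class _QuantConfig:
    """Stand-in for the global quantization config used by quantize()."""
    eval_function = None  # None means no user-provided eval function

g_quant_config = _QuantConfig()

def _fallback_loss():
    """Placeholder for the loss computed when no user eval function is set."""
    return 1.0

### if eval_function is not None, use eval function provided by user.
if g_quant_config.eval_function is not None:
    emd_loss = g_quant_config.eval_function()
else:
    emd_loss = _fallback_loss()

print(emd_loss)  # 1.0 here, since no eval function was configured
```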
