Commit 22cee26

Fix ppminilm export bug (#1596)
* fix ppminilm export bug
* update export script
* update readme
1 parent a5f8a3e commit 22cee26

2 files changed: +8 -5 lines changed

examples/model_compression/pp-minilm/README.md

Lines changed: 4 additions & 2 deletions
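@@ -58,7 +58,9 @@ The PP-MiniLM compression scheme is based on task-agnostic knowledge distillation for pretrained models (Task-a
 - batch sizes: 16, 32, 64;
 - learning rates: 3e-5, 5e-5, 1e-4
 
-2. The quantized model has 0.1M more parameters than the unquantized model because the scale values are saved.
+2. The quantized model has 0.1M more parameters than the unquantized model because the scale values are saved;
+
+3. Performance was measured with batch_size: 32, max_seq_len: 128.
 
 **Workflow**
 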

@@ -197,7 +199,7 @@ export MODEL_PATH=ppminilm-6l-768h
 export LR=1e-4
 export BS=32
 
-python export_model.py --task_name ${TASK_NAME} --output_dir ${MODEL_PATH}/models/${TASK_NAME}/${LR}_${BS}/
+python export_model.py --task_name ${TASK_NAME} --model_path ${MODEL_PATH}/models/${TASK_NAME}/${LR}_${BS}/
 ```
 
 The static-graph (deployment) model is saved to the same path as the dynamic-graph model, under the file names `inference.pdmodel`, `inference.pdiparams`, and `inference.pdiparams.info`.
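
A note on the path handling: the updated script builds its save prefix as `os.path.join(os.path.dirname(args.model_path), "inference")`, so whether the exported files land next to the dynamic-graph model appears to depend on the trailing slash in `--model_path`. The sketch below is only an illustration of that expression. The helper name `export_prefix` and the `TNEWS` task value are hypothetical, and `posixpath` is used in place of `os.path` so the result is the same on every platform.

```python
import posixpath

def export_prefix(model_path: str) -> str:
    """Mirror the save_path expression used by do_export in this commit."""
    return posixpath.join(posixpath.dirname(model_path), "inference")

# Hypothetical expansion of ${MODEL_PATH}/models/${TASK_NAME}/${LR}_${BS}/
model_path = "ppminilm-6l-768h/models/TNEWS/1e-4_32/"

# With the trailing slash (as in the README command), dirname() only strips
# the slash, so the prefix stays inside the model directory.
with_slash = export_prefix(model_path)

# Without the trailing slash, dirname() returns the parent directory instead.
without_slash = export_prefix(model_path.rstrip("/"))

# File names listed in the README for the exported static-graph model.
exported_files = [
    with_slash + ext for ext in (".pdmodel", ".pdiparams", ".pdiparams.info")
]

print(with_slash)
print(without_slash)
```

Under these assumptions, keeping the trailing slash from the README command is what makes the static-graph path match the dynamic-graph path.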

examples/model_compression/pp-minilm/finetuning/export_model.py

Lines changed: 4 additions & 3 deletions
@@ -34,7 +34,7 @@ def parse_args():
         help="The name of the task to train selected in the list: " +
         ", ".join(METRIC_CLASSES.keys()), )
     parser.add_argument(
-        "--output_dir",
+        "--model_path",
         default="best_clue_model",
         type=str,
         help="The output directory where the model predictions and checkpoints will be written.",
@@ -50,9 +50,10 @@ def parse_args():
 
 
 def do_export(args):
-    save_path = os.path.join(args.output_dir, "inference")
-    model = PPMiniLMForSequenceClassification.from_pretrained(args.output_dir)
+    save_path = os.path.join(os.path.dirname(args.model_path), "inference")
+    model = PPMiniLMForSequenceClassification.from_pretrained(args.model_path)
     is_text_pair = True
+    args.task_name = args.task_name.lower()
     if args.task_name in ('tnews', 'iflytek', 'cluewsc2020'):
         is_text_pair = False
     model.to_static(
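
A note on the added `.lower()` call: the membership test compares against lowercase task names, so an uppercase `--task_name` would presumably have fallen through to the text-pair branch before this change. The sketch below isolates that check. The function names and the `afqmc` example task are hypothetical and are not part of `export_model.py`.

```python
# Tasks treated as single-sentence inputs, copied from the diff above.
SINGLE_SENTENCE_TASKS = ('tnews', 'iflytek', 'cluewsc2020')

def is_text_pair_before(task_name: str) -> bool:
    """Check as it stood before the commit: case-sensitive comparison."""
    return task_name not in SINGLE_SENTENCE_TASKS

def is_text_pair_after(task_name: str) -> bool:
    """Check after the commit: normalize the name to lowercase first."""
    return task_name.lower() not in SINGLE_SENTENCE_TASKS

# An uppercase name is misclassified as a text-pair task by the old check.
print(is_text_pair_before("TNEWS"))
print(is_text_pair_after("TNEWS"))
```

Normalizing once at the top of `do_export`, as the commit does, keeps every later use of `args.task_name` consistent instead of lowercasing at each comparison.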
