Skip to content

Commit 00af227

Browse files
authored
Merge pull request #1844 from LemonNoel/deploy
[ERNIE-Health] Add export_model examples
2 parents a6e56e3 + 4a9a3db commit 00af227

File tree

8 files changed

+163
-68
lines changed

8 files changed

+163
-68
lines changed

docs/model_zoo/transformers.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,10 @@ Transformer预训练模型汇总
273273
| | | | 12-heads, 102M parameters. |
274274
| | | | Trained on Chinese text. |
275275
| +----------------------------------------------------------------------------------+--------------+-----------------------------------------+
276+
| |``ernie-health-chinese`` | Chinese | 12-layer, 768-hidden, |
277+
| | | | 12-heads, 102M parameters. |
278+
| | | | Trained on Chinese medical corpus. |
279+
| +----------------------------------------------------------------------------------+--------------+-----------------------------------------+
276280
| |``junnyu/hfl-chinese-electra-180g-base-discriminator`` | Chinese | Discriminator, 12-layer, 768-hidden, |
277281
| | | | 12-heads, 102M parameters. |
278282
| | | | Trained on 180g Chinese text. |
@@ -858,3 +862,4 @@ Reference
858862
- Bao, Siqi, et al. "Plato-2: Towards building an open-domain chatbot via curriculum learning." arXiv preprint arXiv:2006.16779 (2020).
859863
- Yang, Zhilin, et al. "Xlnet: Generalized autoregressive pretraining for language understanding." arXiv preprint arXiv:1906.08237 (2019).
860864
- Cui, Yiming, et al. "Pre-training with whole word masking for chinese bert." arXiv preprint arXiv:1906.08101 (2019).
865+
- Wang, Quan, et al. “Building Chinese Biomedical Language Models via Multi-Level Text Discrimination.” arXiv preprint arXiv:2110.07244 (2021).
Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
# 使用医疗领域预训练模型Fine-tune完成中文医疗语言理解任务
22

3-
近年来,预训练语言模型(Pre-trained Language Model,PLM)逐渐成为自然语言处理(Natural Language Processing,NLP)的主流方法。这类模型可以利用大规模的未标注语料进行训练,得到的模型在下游NLP任务上效果明显提升,在通用领域和特定领域均有广泛应用。在医疗领域,早期的做法是在预先训练好的通用语言模型上进行Fine-tune。后来的研究发现直接使用医疗相关语料学习到的预训练语言模型在医疗文本任务上的效果更好,采用的模型结构也从早期的BERT演变为更新的RoBERTa、ALBERT和ELECTRA
3+
医疗领域存在大量的专业知识和医学术语,人类经过长时间的学习才能成为一名优秀的医生。那机器如何才能“读懂”医疗文献呢?尤其是面对电子病历、生物医疗文献中存在的大量非结构化、非标准化文本,计算机是无法直接使用、处理的。这就需要自然语言处理(Natural Language Processing,NLP)技术大展身手了
44

5-
本示例展示了中文医疗预训练模型eHealth([Building Chinese Biomedical Language Models via Multi-Level Text Discrimination](https://arxiv.org/abs/2110.07244))如何Fine-tune完成中文医疗语言理解任务。
5+
近年来,预训练语言模型(Pre-trained Language Model,PLM)逐渐成为自然语言处理的主流方法。这类模型利用大规模的未标注语料进行训练,得到的模型在下游NLP任务上的效果有着明显提升,在通用领域和特定领域均有广泛应用。在医疗领域,早期的做法是在预先训练好的通用语言模型上进行 Fine-tune。后来的研究发现直接使用医疗相关语料学习到的预训练语言模型在医疗文本任务上的效果更好,采用的模型结构也从早期的BERT演变为更新的 RoBERTa、ALBERT和ELECTRA。
6+
7+
本示例展示了中文医疗预训练模型 ERNIE-Health([Building Chinese Biomedical Language Models via Multi-Level Text Discrimination](https://arxiv.org/abs/2110.07244))如何 Fine-tune完成中文医疗语言理解任务。
68

79
## 模型介绍
810

9-
本项目针对中文医疗语言理解任务,开源了中文医疗预训练模型eHealth(简写`chinese-ehealth`)。eHealth使用了医患对话、科普文章、病历档案、临床病理学教材等脱敏中文语料进行预训练,通过预训练任务设计来学习词级别和句级别的文本信息。该模型的整体结构与ELECTRA相似,包括生成器和判别器两部分。 而Fine-tune过程只用到了判别器模块,由12层Transformer网络组成。
11+
本项目针对中文医疗语言理解任务,开源了中文医疗预训练模型 ERNIE-Health(模型名称`ernie-health-chinese`)。ERNIE-Health模型依托于百度知识增强语义理解框架 ERNIE,以超越人类医学专家水平的成绩登顶中文医疗信息处理权威榜单 CBLUE 冠军, 验证了 ERNIE 在医疗行业应用的重要价值。
12+
13+
![CBLUERank](https://user-images.githubusercontent.com/25607475/160394225-04f75498-ce1a-4665-85f7-d495815eed51.png)
14+
15+
ERNIE-Health 依托百度文心 ERNIE 先进的知识增强预训练语言模型打造, 通过医疗知识增强技术进一步学习海量的医疗数据, 精准地掌握了专业的医学知识。ERNIE-Health 利用医疗实体掩码策略对专业术语等实体级知识学习, 学会了海量的医疗实体知识。同时,通过医疗问答匹配任务学习病患病状描述与医生专业治疗方案的对应关系,获得了医疗实体知识之间的内在联系。ERNIE-Health 共学习了 60 多万的医疗专业术语和 4000 多万的医疗专业问答数据,大幅提升了对医疗专业知识的理解和建模能力。此外,ERNIE-Health 还探索了多级语义判别预训练任务,提升了模型对医疗知识的学习效率。该模型的整体结构与 ELECTRA 相似,包括生成器和判别器两部分。 而 Fine-tune 过程只用到了判别器模块,由 12 层 Transformer 网络组成。
1016

1117
## 数据集介绍
1218

13-
本项目使用了中文医学语言理解测评([Chinese Biomedical Language Understanding Evaluation,CBLUE](https://github.com/CBLUEbenchmark/CBLUE)数据集,其包括医学文本信息抽取(实体识别、关系抽取)、医学术语归一化、医学文本分类、医学句子关系判定和医学问答共5大类任务8个子任务。
19+
本项目使用了中文医学语言理解测评([Chinese Biomedical Language Understanding Evaluation,CBLUE](https://github.com/CBLUEbenchmark/CBLUE)1.0 版本数据集,这是国内首个面向中文医疗文本处理的多任务榜单,涵盖了医学文本信息抽取(实体识别、关系抽取)、医学术语归一化、医学文本分类、医学句子关系判定和医学问答共5大类任务8个子任务。其数据来源分布广泛,包括医学教材、电子病历、临床试验公示以及互联网用户真实查询等。该榜单一经推出便受到了学界和业界的广泛关注,已逐渐发展成为检验AI系统中文医疗信息处理能力的“金标准”
1420

1521
* CMeEE:中文医学命名实体识别
1622
* CMeIE:中文医学文本实体关系抽取
@@ -21,7 +27,7 @@
2127
* KUAKE-QTR:医疗搜索查询词-页面标题相关性
2228
* KUAKE-QQR:医疗搜索查询词-查询词相关性
2329

24-
更多信息可参考CBLUE的[github](https://github.com/CBLUEbenchmark/CBLUE/blob/main/README_ZH.md)。其中对于临床术语标准化任务(CHIP-CDN),我们按照eHealth中的方法通过检索将原多分类任务转换为了二分类任务,即给定一诊断原词和一诊断标准词,要求判定后者是否是前者对应的诊断标准词。本项目提供了检索处理后的CHIP-CDN数据集(简写`CHIP-CDN-2C`),且构建了基于该数据集的example代码。
30+
更多信息可参考CBLUE的[github](https://github.com/CBLUEbenchmark/CBLUE/blob/main/README_ZH.md)。其中对于临床术语标准化任务(CHIP-CDN),我们按照 ERNIE-Health 中的方法通过检索将原多分类任务转换为了二分类任务,即给定一诊断原词和一诊断标准词,要求判定后者是否是前者对应的诊断标准词。本项目提供了检索处理后的 CHIP-CDN 数据集(简写`CHIP-CDN-2C`),且构建了基于该数据集的example代码。
2531

2632
## 快速开始
2733

@@ -34,7 +40,14 @@ cblue/
3440
├── README.md # 使用说明
3541
├── train_classification.py # 分类任务训练评估脚本
3642
├── train_ner.py # 实体识别任务训练评估脚本
37-
└── train_spo.py # 关系抽取任务训练评估脚本
43+
├── train_spo.py # 关系抽取任务训练评估脚本
44+
└── export_model.py #动态图参数导出静态图参数脚本
45+
```
46+
47+
### 依赖安装
48+
49+
```shell
50+
pip install xlrd==1.2.0
3851
```
3952

4053
### 模型训练
@@ -43,28 +56,20 @@ cblue/
4356

4457
**训练参数设置(Training setup)及结果**
4558

46-
| Task | epochs | batch_size | learning_rate | max_seq_length | results |
47-
| --------- | :----: | :--------: | :-----------: | :------------: | :-----: |
48-
| CHIP-STS | 16 | 32 | 1e-4 | 96 | 0.88550 |
49-
| CHIP-CTC | 16 | 32 | 3e-5 | 160 | 0.82790 |
50-
| CHIP-CDN | 16 | 256 | 3e-5 | 32 | 0.76979 |
51-
| KUAKE-QQR | 16 | 32 | 6e-5 | 64 | 0.82364 |
52-
| KUAKE-QTR | 12 | 32 | 6e-5 | 64 | 0.69653 |
53-
| KUAKE-QIC | 4 | 32 | 6e-5 | 128 | 0.81176 |
54-
| CMeEE | 2 | 32 | 6e-5 | 128 | 0.66167 |
55-
| CMeIE | 100 | 12 | 6e-5 | 300 | 0.61385 |
56-
57-
#### 医疗文本分类任务
58-
59-
```shell
60-
$ unset CUDA_VISIBLE_DEVICES
61-
$ python -m paddle.distributed.launch --gpus "0,1,2,3" train_classification.py --dataset CHIP-CDN-2C --batch_size 256 --max_seq_length 32 --learning_rate 3e-5 --epochs 16
62-
```
59+
| Task | epochs | batch_size | learning_rate | max_seq_length | metric | results | results (fp16) |
60+
| --------- | :----: | :--------: | :-----------: | :------------: | :------: | :-----: | :------------: |
61+
| CHIP-STS | 4 | 16 | 1e-4 | 96 | Macro-F1 | 0.88550 | 0.85649 |
62+
| CHIP-CTC | 4 | 32 | 6e-5 | 160 | Macro-F1 | 0.84136 | 0.83514 |
63+
| CHIP-CDN | 16 | 256 | 3e-5 | 32 | F1 | 0.76979 | 0.76489 |
64+
| KUAKE-QQR | 2 | 32 | 6e-5 | 64 | Accuracy | 0.83865 | 0.84053 |
65+
| KUAKE-QTR | 4 | 32 | 6e-5 | 64 | Accuracy | 0.69722 | 0.69722 |
66+
| KUAKE-QIC | 4 | 32 | 6e-5 | 128 | Accuracy | 0.81483 | 0.82046 |
67+
| CMeEE | 2 | 32 | 6e-5 | 128 | Micro-F1 | 0.66120 | 0.66026 |
68+
| CMeIE | 100 | 12 | 6e-5 | 300 | Micro-F1 | 0.61385 | 0.60076 |
6369

6470
可支持配置的参数:
6571

6672
* `save_dir`:可选,保存训练模型的目录;默认保存在当前目录checkpoints文件夹下。
67-
* `dataset`:可选,CHIP-CDN-2C CHIP-CTC CHIP-STS KUAKE-QIC KUAKE-QTR KUAKE-QQR,默认为KUAKE-QIC数据集。
6873
* `max_seq_length`:可选,ELECTRA模型使用的最大序列长度,最大不能超过512, 若出现显存不足,请适当调低这一参数;默认为128。
6974
* `batch_size`:可选,批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。
7075
* `learning_rate`:可选,Fine-tune的最大学习率;默认为6e-5。
@@ -78,26 +83,41 @@ $ python -m paddle.distributed.launch --gpus "0,1,2,3" train_classification.py -
7883
* `seed`:可选,随机种子,默认为1000.
7984
* `device`: 选用什么设备进行训练,可选cpu或gpu。如使用gpu训练则参数gpus指定GPU卡号。
8085
* `use_amp`: 是否使用混合精度训练,默认为False。
81-
* `use_ema`: 是否使用Exponential Moving Average预测,默认为False。
8286

83-
#### 医疗命名实体识别任务
87+
**NOTE:**
88+
* 如需恢复模型训练,则可以设置`init_from_ckpt`, 如`init_from_ckpt=checkpoints/model_100/model_state.pdparams`
89+
* 使用动态图训练结束之后,还可以将动态图参数导出成静态图参数,具体代码见export_model.py。静态图参数保存在`output_path`指定路径中。
90+
运行方式:
8491

8592
```shell
86-
$ export CUDA_VISIBLE_DEVICES=0
87-
$ python train_ner.py --batch_size 32 --max_seq_length 128 --learning_rate 6e-5 --epochs 12
93+
python export_model.py --train_dataset CMeIE --params_path=./checkpoint/model_900/model_state.pdparams --output_path=./export
8894
```
8995

90-
#### 医疗关系抽取任务
96+
#### 医疗文本分类任务
97+
98+
```shell
99+
$ unset CUDA_VISIBLE_DEVICES
100+
$ python -m paddle.distributed.launch --gpus "0,1,2,3" train_classification.py --dataset CHIP-CDN-2C --batch_size 256 --max_seq_length 32 --learning_rate 3e-5 --epochs 16
101+
```
102+
103+
其他可支持配置的参数:
104+
105+
* `dataset`:可选,CHIP-CDN-2C CHIP-CTC CHIP-STS KUAKE-QIC KUAKE-QTR KUAKE-QQR,默认为KUAKE-QIC数据集。
106+
107+
#### 医疗命名实体识别任务(CMeEE)
91108

92109
```shell
93110
$ export CUDA_VISIBLE_DEVICES=0
94-
$ python train_spo.py --batch_size 12 --max_seq_length 300 --learning_rate 6e-5 --epochs 100
111+
$ python train_ner.py --batch_size 32 --max_seq_length 128 --learning_rate 6e-5 --epochs 12
95112
```
96113

97-
### 依赖安装
114+
#### 医疗关系抽取任务(CMeIE)
98115

99116
```shell
100-
pip install xlrd==1.2.0
117+
$ export CUDA_VISIBLE_DEVICES=0
118+
$ python train_spo.py --batch_size 12 --max_seq_length 300 --learning_rate 6e-5 --epochs 100
101119
```
102120

103121
[1] CBLUE: A Chinese Biomedical Language Understanding Evaluation Benchmark [pdf](https://arxiv.org/abs/2106.08087) [git](https://github.com/CBLUEbenchmark/CBLUE) [web](https://tianchi.aliyun.com/specials/promotion/2021chinesemedicalnlpleaderboardchallenge)
122+
123+
[2] Wang, Quan, et al. “Building Chinese Biomedical Language Models via Multi-Level Text Discrimination.” arXiv preprint arXiv:2110.07244 (2021). [pdf](https://arxiv.org/abs/2110.07244)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import os
17+
18+
import paddle
19+
from paddlenlp.transformers import ElectraForSequenceClassification
20+
from model import ElectraForBinaryTokenClassification, ElectraForSPO
21+
22+
# yapf: disable
23+
parser = argparse.ArgumentParser()
24+
parser.add_argument('--train_dataset', choices=['KUAKE-QIC', 'KUAKE-QQR', 'KUAKE-QTR', 'CHIP-STS', 'CHIP-CTC', 'CHIP-CDN-2C', 'CMeEE', 'CMeIE'],
25+
required=True, type=str, help='The name of dataset used for training.')
26+
parser.add_argument('--params_path', type=str, required=True, default='./checkpoint/model_state.pdparams', help='The path to model parameters to be loaded.')
27+
parser.add_argument('--output_path', type=str, default='./export', help='The path of model parameter in static graph to be saved.')
28+
args = parser.parse_args()
29+
# yapf: enable
30+
31+
NUM_CLASSES = {
32+
'CHIP-CDN-2C': 2,
33+
'CHIP-STS': 2,
34+
'CHIP-CTC': 44,
35+
'KUAKE-QQR': 3,
36+
'KUAKE-QTR': 4,
37+
'KUAKE-QIC': 11,
38+
'CMeEE': [33, 5],
39+
'CMeIE': 44
40+
}
41+
42+
if __name__ == "__main__":
43+
if args.train_dataset == 'CMeEE':
44+
model = ElectraForBinaryTokenClassification.from_pretrained(
45+
'ernie-health-chinese', num_classes=NUM_CLASSES[args.train_dataset])
46+
elif args.train_dataset == 'CMeIE':
47+
model = ElectraForSPO.from_pretrained(
48+
'ernie-health-chinese', num_classes=NUM_CLASSES[args.train_dataset])
49+
else:
50+
model = ElectraForSequenceClassification.from_pretrained(
51+
'ernie-health-chinese',
52+
num_classes=NUM_CLASSES[args.train_dataset],
53+
activation='tanh')
54+
55+
if args.params_path and os.path.isfile(args.params_path):
56+
state_dict = paddle.load(args.params_path)
57+
model.set_dict(state_dict)
58+
print("Loaded parameters from %s" % args.params_path)
59+
model.eval()
60+
61+
# Convert to static graph with specific input description
62+
input_spec = [
63+
paddle.static.InputSpec(
64+
shape=[None, None], dtype="int64"), # input_ids
65+
paddle.static.InputSpec(
66+
shape=[None, None], dtype="int64"), # token_type_ids
67+
paddle.static.InputSpec(
68+
shape=[None, None], dtype="int64") # position_ids
69+
]
70+
if args.train_dataset in ['CMeEE', 'CMeIE']:
71+
input_spec.append(
72+
paddle.static.InputSpec(
73+
shape=[None, None], dtype="float32")) # masks
74+
75+
model = paddle.jit.to_static(model, input_spec=input_spec)
76+
# Save in static graph model.
77+
save_path = os.path.join(args.output_path, "inference")
78+
paddle.jit.save(model, save_path)

0 commit comments

Comments
 (0)