
Commit eb5cf37

Merge pull request #1782 from LemonNoel/ner_spo_debug
[EHealth] Finetune examples of Classification, NER and SPO
2 parents 83743cc + 40f45e2 commit eb5cf37

15 files changed, +1717 -37 lines changed
examples/biomedical/cblue/README.md

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
# Fine-tuning a Medical-Domain Pretrained Model for Chinese Medical Language Understanding

In recent years, pre-trained language models (PLMs) have become the mainstream approach in natural language processing (NLP). These models are trained on large unlabeled corpora and deliver clear gains on downstream NLP tasks, with wide adoption in both general and specialized domains. In the medical domain, early work fine-tuned general-purpose pretrained language models directly. Later studies found that language models pretrained on medical corpora perform better on medical text tasks, and the architectures have evolved from the original BERT to newer variants such as RoBERTa, ALBERT, and ELECTRA.

This example shows how to fine-tune the Chinese medical pretrained model eHealth ([Building Chinese Biomedical Language Models via Multi-Level Text Discrimination](https://arxiv.org/abs/2110.07244)) on Chinese medical language understanding tasks.

## Model Introduction

For Chinese medical language understanding, this project open-sources the Chinese medical pretrained model eHealth (abbreviated `chinese-ehealth`). eHealth is pretrained on de-identified Chinese corpora including doctor-patient dialogues, popular-science articles, medical records, and clinical pathology textbooks, with pretraining tasks designed to capture both token-level and sentence-level information. The overall architecture is similar to ELECTRA, consisting of a generator and a discriminator; only the discriminator, a 12-layer Transformer network, is used during fine-tuning.
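Under PaddleNLP, the discriminator can be loaded through the standard ELECTRA classes. A minimal loading sketch (assuming the eHealth weights are available under the `chinese-ehealth` name abbreviated above; otherwise pass a local checkpoint directory):

```python
from paddlenlp.transformers import ElectraTokenizer, ElectraForSequenceClassification

# Replace "chinese-ehealth" with a local checkpoint directory if the weights
# are not registered under that name in your PaddleNLP version.
tokenizer = ElectraTokenizer.from_pretrained("chinese-ehealth")
# num_classes depends on the downstream task; 2 is used here only as an example.
model = ElectraForSequenceClassification.from_pretrained("chinese-ehealth", num_classes=2)
```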
## Dataset Introduction

This project uses the Chinese Biomedical Language Understanding Evaluation benchmark ([CBLUE](https://github.com/CBLUEbenchmark/CBLUE)), which covers 5 task categories and 8 subtasks: medical information extraction (entity recognition and relation extraction), medical term normalization, medical text classification, medical sentence-pair relation judgment, and medical question answering.

* CMeEE: Chinese medical named entity recognition
* CMeIE: Chinese medical entity relation extraction
* CHIP-CDN: clinical term normalization
* CHIP-CTC: short-text classification of clinical trial eligibility criteria
* CHIP-STS: transfer learning for disease-related question matching (Ping An Healthcare Technology)
* KUAKE-QIC: medical search query intent classification
* KUAKE-QTR: medical search query-page title relevance
* KUAKE-QQR: medical search query-query relevance

See the CBLUE [GitHub repository](https://github.com/CBLUEbenchmark/CBLUE/blob/main/README_ZH.md) for more details. For the clinical term normalization task (CHIP-CDN), we follow the eHealth approach and convert the original multi-class task into a binary classification task via retrieval: given a raw diagnosis term and a candidate standard term, the model decides whether the latter is the standard form of the former. This project provides the retrieval-processed CHIP-CDN dataset (abbreviated `CHIP-CDN-2C`) and example code built on it.
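For example, a retrieval-converted CHIP-CDN-2C sample pairs one raw diagnosis term with one candidate standard term and a 0/1 label (the terms below are purely illustrative, and the actual file layout may differ):

```text
raw term              candidate standard term    label
左膝髌骨粉碎性骨折      髌骨骨折                    1
左膝髌骨粉碎性骨折      膝关节韧带损伤              0
```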
## Quick Start

### Code Structure

The main code structure of this project is as follows:

```text
cblue/
├── README.md # usage instructions
├── train_classification.py # training and evaluation script for classification tasks
├── train_ner.py # training and evaluation script for the entity recognition task
└── train_spo.py # training and evaluation script for the relation extraction task
```
### Model Training

Example code is provided for all 8 tasks, grouped by task category. Run the commands below to train on the training set and evaluate on the development set. The command shown runs the classification script on CHIP-CDN-2C; use `train_ner.py` or `train_spo.py` for the entity recognition and relation extraction tasks.

```shell
$ unset CUDA_VISIBLE_DEVICES
$ python -m paddle.distributed.launch --gpus "0,1,2,3" train_classification.py --dataset CHIP-CDN-2C --batch_size 256 --max_seq_length 32 --learning_rate 3e-5 --epochs 16
```
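For a quick single-card run, the classification script can also be launched directly, for example with the KUAKE-QIC hyperparameters listed in the next section (assuming `train_classification.py` accepts the same flags as the distributed launch above):

```shell
$ export CUDA_VISIBLE_DEVICES=0
$ python train_classification.py --dataset KUAKE-QIC --batch_size 32 --max_seq_length 128 --learning_rate 6e-5 --epochs 4 --device gpu
```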
### Training Setup and Results

| Task | epochs | batch_size | learning_rate | max_seq_length | results |
| --------- | :----: | :--------: | :-----------: | :------------: | :-----: |
| CHIP-STS | 16 | 32 | 1e-4 | 96 | 0.88550 |
| CHIP-CTC | 16 | 32 | 3e-5 | 160 | 0.82790 |
| CHIP-CDN | 16 | 256 | 3e-5 | 32 | 0.76979 |
| KUAKE-QQR | 16 | 32 | 6e-5 | 64 | 0.82364 |
| KUAKE-QTR | 12 | 32 | 6e-5 | 64 | 0.69653 |
| KUAKE-QIC | 4 | 32 | 6e-5 | 128 | 0.81176 |
Configurable parameters:

* `save_dir`: optional; directory for saving trained models. Defaults to the checkpoints folder under the current directory.
* `dataset`: optional; one of CHIP-CDN-2C, CHIP-CTC, CHIP-STS, KUAKE-QIC, KUAKE-QTR, KUAKE-QQR. Defaults to KUAKE-QIC.
* `max_seq_length`: optional; maximum sequence length used by the ELECTRA model, at most 512. Lower this value if you run out of GPU memory. Defaults to 128.
* `batch_size`: optional; batch size. Adjust it to your GPU memory and lower it if you run out of memory. Defaults to 32.
* `learning_rate`: optional; peak learning rate for fine-tuning. Defaults to 6e-5.
* `weight_decay`: optional; weight of the regularization term, used to prevent overfitting. Defaults to 0.01.
* `epochs`: number of training epochs. Defaults to 3.
* `valid_steps`: interval (in steps) between evaluations. Defaults to 100.
* `save_steps`: interval (in steps) between checkpoint saves. Defaults to 100.
* `logging_steps`: interval (in steps) between log prints. Defaults to 10.
* `warmup_proption`: optional; proportion of steps used for learning-rate warmup. With 0.1, the learning rate grows from 0 to learning_rate over the first 10% of training steps and then decays slowly (see the scheduler sketch after this list). Defaults to 0.1.
* `init_from_ckpt`: optional; path to model parameters used to warm-start training. Defaults to None.
* `seed`: optional; random seed. Defaults to 1000.
* `device`: device to train on, either cpu or gpu. For GPU training, use the `gpus` argument to specify card numbers.
* `use_amp`: whether to use automatic mixed-precision training. Defaults to False.
* `use_ema`: whether to use Exponential Moving Average for prediction. Defaults to False.
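For reference, here is a minimal sketch of how the warmup proportion and peak learning rate map onto PaddleNLP's linear warmup/decay scheduler; the step counts and the stand-in model are illustrative, not taken from the training scripts:

```python
import paddle
from paddlenlp.transformers import LinearDecayWithWarmup

learning_rate, epochs, steps_per_epoch, warmup_proportion = 6e-5, 3, 1000, 0.1
num_training_steps = epochs * steps_per_epoch

# The LR rises linearly from 0 to learning_rate over the first 10% of steps,
# then decays linearly towards 0 for the remaining steps.
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion)

model = paddle.nn.Linear(4, 2)  # stand-in for the fine-tuning model
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=0.01)
```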
### Install Dependencies

```shell
pip install xlrd==1.2.0
```

[1] CBLUE: A Chinese Biomedical Language Understanding Evaluation Benchmark [pdf](https://arxiv.org/abs/2106.08087) [git](https://github.com/CBLUEbenchmark/CBLUE) [web](https://tianchi.aliyun.com/specials/promotion/2021chinesemedicalnlpleaderboardchallenge)

examples/biomedical/cblue/model.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
import paddle
import paddle.nn as nn
from paddlenlp.transformers import ElectraPretrainedModel


class ElectraForBinaryTokenClassification(ElectraPretrainedModel):
    """
    Electra Model with two linear layers on top of the hidden-states output layers,
    designed for token classification tasks with nesting.

    Args:
        electra (:class:`ElectraModel`):
            An instance of ElectraModel.
        num_classes (list):
            The number of classes for each of the two label sets.
        dropout (float, optional):
            The dropout probability for output of Electra.
            If None, use the same value as `hidden_dropout_prob` of the
            `ElectraModel` instance `electra`. Defaults to None.
    """

    def __init__(self, electra, num_classes, dropout=None):
        super(ElectraForBinaryTokenClassification, self).__init__()
        assert len(num_classes) == 2
        self.num_classes_oth = num_classes[0]
        self.num_classes_sym = num_classes[1]
        self.electra = electra
        self.dropout = nn.Dropout(dropout if dropout is not None else
                                  self.electra.config['hidden_dropout_prob'])
        self.classifier_oth = nn.Linear(self.electra.config['hidden_size'],
                                        self.num_classes_oth)
        self.classifier_sym = nn.Linear(self.electra.config['hidden_size'],
                                        self.num_classes_sym)
        self.init_weights()

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        sequence_output = self.electra(input_ids, token_type_ids, position_ids,
                                       attention_mask)
        sequence_output = self.dropout(sequence_output)

        # One set of per-token logits for each of the two label sets.
        logits_sym = self.classifier_sym(sequence_output)
        logits_oth = self.classifier_oth(sequence_output)
        return logits_oth, logits_sym


class MultiHeadAttentionForSPO(nn.Layer):
    """
    Multi-head attention layer for the SPO task: produces one
    [seq_len, seq_len] score map per head.
    """

    def __init__(self, embed_dim, num_heads, scale_value=768):
        super(MultiHeadAttentionForSPO, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.scale_value = scale_value**-0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim * num_heads)
        self.k_proj = nn.Linear(embed_dim, embed_dim * num_heads)

    def forward(self, query, key):
        q = self.q_proj(query)
        k = self.k_proj(key)
        q = paddle.reshape(q, shape=[0, 0, self.num_heads, self.embed_dim])
        k = paddle.reshape(k, shape=[0, 0, self.num_heads, self.embed_dim])
        q = paddle.transpose(q, perm=[0, 2, 1, 3])
        k = paddle.transpose(k, perm=[0, 2, 1, 3])
        scores = paddle.matmul(q, k, transpose_y=True)
        scores = paddle.scale(scores, scale=self.scale_value)
        return scores


class ElectraForSPO(ElectraPretrainedModel):
    """
    Electra Model with a linear layer on top of the hidden-states output
    layers for entity recognition, and a multi-head attention layer for
    relation classification.

    Args:
        electra (:class:`ElectraModel`):
            An instance of ElectraModel.
        num_classes (int):
            The number of classes.
        dropout (float, optional):
            The dropout probability for output of Electra.
            If None, use the same value as `hidden_dropout_prob` of the
            `ElectraModel` instance `electra`. Defaults to None.
    """

    def __init__(self, electra, num_classes, dropout=None):
        super(ElectraForSPO, self).__init__()
        self.num_classes = num_classes
        self.electra = electra
        self.dropout = nn.Dropout(dropout if dropout is not None else
                                  self.electra.config['hidden_dropout_prob'])
        self.classifier = nn.Linear(self.electra.config['hidden_size'], 2)
        self.span_attention = MultiHeadAttentionForSPO(
            self.electra.config['hidden_size'], num_classes)
        self.sigmoid = paddle.nn.Sigmoid()

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        sequence_outputs, _, all_hidden_states = self.electra(
            input_ids,
            token_type_ids,
            position_ids,
            attention_mask,
            output_hidden_states=True)
        sequence_outputs = self.dropout(sequence_outputs)
        # Per-token entity logits (two channels).
        ent_logits = self.classifier(sequence_outputs)

        # Fuse the second-to-last hidden states with the broadcast [CLS]
        # representation before scoring token pairs for relations.
        subject_output = all_hidden_states[-2]
        cls_output = paddle.unsqueeze(sequence_outputs[:, 0, :], axis=1)
        subject_output = subject_output + cls_output

        rel_logits = self.span_attention(sequence_outputs, subject_output)

        ent_logits = self.sigmoid(ent_logits)
        rel_logits = self.sigmoid(rel_logits)

        return ent_logits, rel_logits
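

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the committed file). It uses a
# generic Chinese ELECTRA checkpoint only to demonstrate tensor shapes; for
# the CBLUE experiments, substitute the eHealth discriminator weights and the
# label sets built by the training scripts.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from paddlenlp.transformers import ElectraModel, ElectraTokenizer

    tokenizer = ElectraTokenizer.from_pretrained('chinese-electra-base')
    backbone = ElectraModel.from_pretrained('chinese-electra-base')

    # Two heads for nested token classification; the class counts below are
    # placeholders, not the real CMeEE label sizes.
    model = ElectraForBinaryTokenClassification(backbone, num_classes=[5, 3])

    encoded = tokenizer('患者近一周出现发热伴咳嗽。')
    input_ids = paddle.to_tensor([encoded['input_ids']])
    token_type_ids = paddle.to_tensor([encoded['token_type_ids']])

    logits_oth, logits_sym = model(input_ids, token_type_ids)
    print(logits_oth.shape)  # [1, seq_len, 5]
    print(logits_sym.shape)  # [1, seq_len, 3]

    # Pairwise scores used for relation extraction: one seq_len x seq_len map per head.
    attn = MultiHeadAttentionForSPO(embed_dim=8, num_heads=4)
    hidden = paddle.randn([1, 6, 8])
    scores = attn(hidden, hidden)
    print(scores.shape)  # [1, 4, 6, 6]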
