Commit 15f0aa8

[ehealth] add sequence classification example

1 parent 275e62a commit 15f0aa8
File tree

9 files changed: +626 -56 lines changed
Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
# Fine-tuning a Medical-Domain Pre-trained Model for Chinese Medical Text Classification

In recent years, pre-trained language models (PLMs) have become the dominant approach in natural language processing (NLP). These models are trained on large-scale unlabeled corpora and bring clear gains on downstream NLP tasks, with wide application in both general and specialized domains. In the medical domain, early work fine-tuned general-purpose pre-trained language models directly. Later studies found that pre-trained language models learned directly from medical corpora perform better on medical text tasks, and the model architectures evolved from the early BERT to newer designs such as RoBERTa, ALBERT, and ELECTRA.

This example shows how to fine-tune the Chinese medical pre-trained model eHealth ([Building Chinese Biomedical Language Models via Multi-Level Text Discrimination](https://arxiv.org/abs/2110.07244)) for Chinese medical text classification tasks.
## Model Introduction

For Chinese medical text classification, this project open-sources the Chinese medical pre-trained model eHealth (short name `chinese-ehealth`). eHealth ([Building Chinese Biomedical Language Models via Multi-Level Text Discrimination](https://arxiv.org/abs/2110.07244)) is pre-trained on de-identified Chinese corpora such as doctor–patient dialogues, popular-science medical articles, medical records, and clinical pathology textbooks, with pre-training tasks designed to capture both token-level and sentence-level information. The overall architecture is similar to ELECTRA, consisting of a generator and a discriminator; only the discriminator, a 12-layer Transformer network, is used during fine-tuning.
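Since fine-tuning only uses the discriminator, loading eHealth for a classification task works like loading any ELECTRA checkpoint in PaddleNLP. Below is a minimal sketch of that step, mirroring what `train.py` does; the two-class `num_classes` and the sample sentence are placeholders for illustration only.

```python
import paddle
import paddlenlp as ppnlp

# Load the eHealth discriminator with a sequence classification head.
# num_classes=2 is a placeholder; train.py sets it from the dataset's label list.
model = ppnlp.transformers.ElectraForSequenceClassification.from_pretrained(
    'chinese-ehealth', num_classes=2)
tokenizer = ppnlp.transformers.ElectraTokenizer.from_pretrained('chinese-ehealth')

# Tokenize a sample sentence and run a single forward pass.
encoded = tokenizer('患者出现持续低热伴咳嗽两周')
input_ids = paddle.to_tensor([encoded['input_ids']])
token_type_ids = paddle.to_tensor([encoded['token_type_ids']])
logits = model(input_ids, token_type_ids)
print(logits.shape)  # [1, 2]
```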
## Quick Start

### Code Structure

The main files in this project are:

```text
sequence_classification/
├── README.md   # usage instructions
└── train.py    # training and evaluation script
```
### Model Training

We use the text classification datasets from the Chinese medical text benchmark CBLUE as example datasets, including the following tasks (a short loading sketch follows the list):

* CHIP-CDN: given a medical record, predict the normalized diagnosis entities it contains. This project uses a binary classification dataset rebuilt after a retrieval step: given a medical record and a normalized diagnosis entity, predict whether the former contains the latter (short name `CHIP-CDN-2C`).
* CHIP-CTC: given a clinical text description, classify it according to Chinese clinical trial screening criteria.
* CHIP-STS: given two sentences covering 5 different diseases, predict whether they are semantically similar.
* KUAKE-QIC: given a medical query, classify the intent of the patient's question.
* KUAKE-QTR: given a medical query and a document title, predict whether their contents match.
* KUAKE-QQR: given two medical queries, predict whether they describe the same content.
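All of these datasets are available through PaddleNLP's `cblue` dataset loader, which is the same call `train.py` makes internally. A short sketch for inspecting one of them (the printed fields are an assumption; the exact keys vary per task):

```python
from paddlenlp.datasets import load_dataset

# Load the train and dev splits of one CBLUE task by name.
train_ds, dev_ds = load_dataset('cblue', 'KUAKE-QIC', splits=['train', 'dev'])

print(len(train_ds), 'training examples')
print(train_ds.label_list)  # label set used to size the classification head
print(train_ds[0])          # one raw example (text plus label)
```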
Run the commands below to train on the training set and evaluate on the dev set:

```shell
$ unset CUDA_VISIBLE_DEVICES
$ python -m paddle.distributed.launch --gpus "0" train.py --dataset CHIP-CDN-2C --batch_size 256 --max_seq_length 32 --weight_decay 0.01 --warmup_proportion 0.1
```
Configurable parameters:

* `save_dir`: optional; directory where trained model checkpoints are saved. Defaults to the `checkpoint` folder under the current directory.
* `dataset`: optional; one of CHIP-CDN-2C, CHIP-CTC, CHIP-STS, KUAKE-QIC, KUAKE-QTR, KUAKE-QQR. Defaults to KUAKE-QIC.
* `max_seq_length`: optional; maximum sequence length used by the ELECTRA model, at most 512. Lower this value if you run out of GPU memory. Defaults to 128.
* `batch_size`: optional; batch size. Adjust it to the available GPU memory and lower it if you run out of memory. Defaults to 32.
* `learning_rate`: optional; peak learning rate for fine-tuning. Defaults to 6e-5.
* `weight_decay`: optional; strength of the weight-decay regularization used to prevent overfitting. Defaults to 0.01.
* `epochs`: number of training epochs. Defaults to 3.
* `valid_steps`: interval, in steps, between evaluations. Defaults to 100.
* `save_steps`: interval, in steps, between checkpoint saves. Defaults to 100.
* `logging_steps`: interval, in steps, between log prints. Defaults to 10.
* `warmup_proportion`: optional; proportion of training steps used for learning-rate warmup. With 0.1, the learning rate grows linearly from 0 to `learning_rate` over the first 10% of training steps and then decays slowly (see the scheduler sketch after this list). Defaults to 0.1.
* `init_from_ckpt`: optional; path to model parameters used to warm-start training. Defaults to None.
* `seed`: optional; random seed. Defaults to 1000.
* `device`: device to train on, either cpu or gpu. When training on GPU, the `--gpus` argument selects the card IDs.
* `use_amp`: whether to use mixed precision training. Defaults to False.
* `use_ema`: whether to use an Exponential Moving Average of the model weights for evaluation. Defaults to False.
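To make `warmup_proportion` concrete, the sketch below builds the same `LinearDecayWithWarmup` schedule that `train.py` uses; the total step count of 1000 is an arbitrary assumption for illustration.

```python
from paddlenlp.transformers import LinearDecayWithWarmup

learning_rate = 6e-5
num_training_steps = 1000  # assumed total number of training steps, for illustration
warmup_proportion = 0.1    # warm up over the first 10% of steps, i.e. 100 steps

scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion)

# Step the schedule as the training loop would: the learning rate rises linearly
# to 6e-5 by step 100, then decays linearly towards 0 at the final step.
for step in range(1, num_training_steps + 1):
    scheduler.step()
    if step in (50, 100, 500, 1000):
        print(step, scheduler.last_lr)
```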
### Dependency Installation

```shell
pip install xlrd==1.2.0
```

Lines changed: 270 additions & 0 deletions

@@ -0,0 +1,270 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import argparse
import os
import random
import time
import distutils.util

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.metric import Accuracy
import paddlenlp as ppnlp
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.metrics import MultiLabelsMetric
from paddlenlp.ops.optimizer import ExponentialMovingAverage

from utils import convert_example

METRIC_CLASSES = {
    'KUAKE-QIC': Accuracy,
    'KUAKE-QQR': Accuracy,
    'KUAKE-QTR': Accuracy,
    'CHIP-CTC': partial(
        MultiLabelsMetric, name='macro'),
    'CHIP-STS': partial(
        MultiLabelsMetric, name='macro'),
    'CHIP-CDN-2C': partial(
        MultiLabelsMetric, name='micro')
}

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', choices=['KUAKE-QIC', 'KUAKE-QQR', 'KUAKE-QTR', 'CHIP-STS', 'CHIP-CTC', 'CHIP-CDN-2C'],
                    default='KUAKE-QIC', type=str, help='Dataset for sequence classification tasks.')
parser.add_argument('--seed', default=1000, type=int, help='Random seed for initialization.')
parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu', 'npu'], default='gpu', help='Select which device to train model, default to gpu.')
parser.add_argument('--epochs', default=3, type=int, help='Total number of training epochs to perform.')
parser.add_argument('--batch_size', default=32, type=int, help='Batch size per GPU/CPU for training.')
parser.add_argument('--learning_rate', default=6e-5, type=float, help='Learning rate for fine-tuning sequence classification task.')
parser.add_argument('--weight_decay', default=0.01, type=float, help='Weight decay if we apply some.')
parser.add_argument('--warmup_proportion', default=0.1, type=float, help='Linear warmup proportion over the training process.')
parser.add_argument('--max_seq_length', default=128, type=int, help='The maximum total input sequence length after tokenization.')
parser.add_argument('--init_from_ckpt', default=None, type=str, help='The path of checkpoint to be loaded.')
parser.add_argument('--logging_steps', default=10, type=int, help='The interval steps to logging.')
parser.add_argument('--save_dir', default='./checkpoint', type=str, help='The output directory where the model checkpoints will be written.')
parser.add_argument('--save_steps', default=100, type=int, help='The interval steps to save checkpoints.')
parser.add_argument('--valid_steps', default=100, type=int, help='The interval steps to evaluate model performance.')
parser.add_argument('--use_ema', default=False, type=distutils.util.strtobool, help='Use exponential moving average for evaluation.')
parser.add_argument('--use_amp', default=False, type=distutils.util.strtobool, help='Enable mixed precision training.')
parser.add_argument('--scale_loss', default=128, type=float, help='The value of scale_loss for fp16.')

args = parser.parse_args()
# yapf: enable


def set_seed(seed):
    """set random seed"""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)


@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader):
    """
    Given a dataset, it evaluates the model and computes the metric.

    Args:
        model(obj:`paddle.nn.Layer`): A model to classify texts.
        criterion(obj:`paddle.nn.Layer`): It can compute the loss.
        metric(obj:`paddle.metric.Metric`): The evaluation metric.
        data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches.
    """
    model.eval()
    metric.reset()
    losses = []
    for batch in data_loader:
        input_ids, token_type_ids, position_ids, labels = batch
        logits = model(input_ids, token_type_ids, position_ids)
        loss = criterion(logits, labels)
        losses.append(loss.numpy())
        correct = metric.compute(logits, labels)
        metric.update(correct)
    if isinstance(metric, Accuracy):
        metric_name = 'accuracy'
        result = metric.accumulate()
    else:
        metric_name = metric._name + ' f1'
        _, _, result = metric.accumulate(metric._name)
    print('eval loss: %.5f, %s: %.5f' % (np.mean(losses), metric_name, result))
    model.train()
    metric.reset()


def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)


def do_train():
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    train_ds, dev_ds, test_ds = load_dataset(
        'cblue', args.dataset, splits=['train', 'dev', 'test'])

    model = ppnlp.transformers.ElectraForSequenceClassification.from_pretrained(
        'chinese-ehealth', num_classes=len(train_ds.label_list))
    tokenizer = ppnlp.transformers.ElectraTokenizer.from_pretrained(
        'chinese-ehealth')

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
        Pad(axis=0, pad_val=args.max_seq_length - 1),  # position
        Stack(dtype='int64')): [data for data in fn(samples)]
    train_data_loader = create_dataloader(
        train_ds,
        mode='train',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)
    dev_data_loader = create_dataloader(
        dev_ds,
        mode='dev',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
        state_dict = paddle.load(args.init_from_ckpt)
        model.set_dict(state_dict)
    model = paddle.DataParallel(model)

    num_training_steps = len(train_data_loader) * args.epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
                                         args.warmup_proportion)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ['bias', 'norm'])
    ]

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    criterion = paddle.nn.loss.CrossEntropyLoss()
    if METRIC_CLASSES[args.dataset] is Accuracy:
        metric = METRIC_CLASSES[args.dataset]()
        metric_name = 'accuracy'
    else:
        metric = METRIC_CLASSES[args.dataset](
            num_labels=len(train_ds.label_list))
        metric_name = metric._name + ' f1'
    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
    if args.use_ema and rank == 0:
        ema = ExponentialMovingAverage(model)
        ema.register()
    global_step = 0
    tic_train = time.time()
    total_train_time = 0
    for epoch in range(1, args.epochs + 1):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, position_ids, labels = batch
            with paddle.amp.auto_cast(
                    args.use_amp,
                    custom_white_list=['layer_norm', 'softmax', 'gelu']):
                logits = model(input_ids, token_type_ids, position_ids)
                loss = criterion(logits, labels)
            probs = F.softmax(logits, axis=1)
            correct = metric.compute(probs, labels)
            metric.update(correct)

            if isinstance(metric, Accuracy):
                result = metric.accumulate()
            else:
                _, _, result = metric.accumulate(metric._name)

            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()
            lr_scheduler.step()
            if args.use_ema and rank == 0:
                ema.update()
            optimizer.clear_grad()

            global_step += 1
            if global_step % args.logging_steps == 0 and rank == 0:
                time_diff = time.time() - tic_train
                total_train_time += time_diff
                print(
                    'global step %d, epoch: %d, batch: %d, loss: %.5f, %s: %.5f, speed: %.2f step/s'
                    % (global_step, epoch, step, loss, metric_name, result,
                       args.logging_steps / time_diff))
                tic_train = time.time()

            if global_step % args.valid_steps == 0 and rank == 0:
                if args.use_ema:
                    ema.apply_shadow()
                    evaluate(model, criterion, metric, dev_data_loader)
                    ema.restore()
                else:
                    evaluate(model, criterion, metric, dev_data_loader)
                tic_train = time.time()

            if global_step % args.save_steps == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir, 'model_%d' % global_step)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                model._layers.save_pretrained(save_dir)
                tokenizer.save_pretrained(save_dir)
                tic_train = time.time()

    print('Speed: %.2f steps/s' % (global_step / total_train_time))


if __name__ == "__main__":
    do_train()
