
Commit ff02709

[ehealth] draft version of ner and spo tasks
1 parent ba3ea1c commit ff02709

File tree

6 files changed: +722 -35 lines changed


examples/biomedical/cblue/model.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
import paddle
import paddle.nn as nn

from paddlenlp.transformers import ElectraPretrainedModel


class ElectraForBinaryTokenClassification(ElectraPretrainedModel):
    """
    Electra Model with two linear classification layers on top of the
    hidden-states output, designed for nested token classification tasks.

    Args:
        electra (:class:`ElectraModel`):
            An instance of ElectraModel.
        num_classes (list):
            A list of two integers: the number of classes of the 'others'
            head and of the 'symptom' head.
        dropout (float, optional):
            The dropout probability for the output of Electra.
            If None, use the same value as `hidden_dropout_prob` of the
            `ElectraModel` instance `electra`. Defaults to None.
    """

    def __init__(self, electra, num_classes, dropout=None):
        super(ElectraForBinaryTokenClassification, self).__init__()
        assert len(num_classes) == 2
        self.num_classes_oth = num_classes[0]
        self.num_classes_sym = num_classes[1]
        self.electra = electra
        self.dropout = nn.Dropout(dropout if dropout is not None else
                                  self.electra.config['hidden_dropout_prob'])
        self.classifier_oth = nn.Linear(self.electra.config['hidden_size'],
                                        self.num_classes_oth)
        self.classifier_sym = nn.Linear(self.electra.config['hidden_size'],
                                        self.num_classes_sym)
        self.init_weights()

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        r"""
        The ElectraForBinaryTokenClassification forward method, overrides the
        __call__() special method.

        TODO
        """
        sequence_output = self.electra(input_ids, token_type_ids, position_ids,
                                       attention_mask)
        sequence_output = self.dropout(sequence_output)

        logits_sym = self.classifier_sym(sequence_output)
        logits_oth = self.classifier_oth(sequence_output)
        return logits_oth, logits_sym


class MultiHeadAttentionForSPO(nn.Layer):
    """
    Multi-head attention layer that returns the raw (scaled) attention scores,
    one score matrix per head; each head corresponds to one predicate class.
    """

    def __init__(self, embed_dim, num_heads, scale_value=768):
        super(MultiHeadAttentionForSPO, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.scale_value = scale_value**-0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim * num_heads)
        self.k_proj = nn.Linear(embed_dim, embed_dim * num_heads)

    def forward(self, query, key):
        q = self.q_proj(query)
        k = self.k_proj(key)
        q = paddle.reshape(q, shape=[0, 0, self.num_heads, self.embed_dim])
        k = paddle.reshape(k, shape=[0, 0, self.num_heads, self.embed_dim])
        q = paddle.transpose(q, perm=[0, 2, 1, 3])
        k = paddle.transpose(k, perm=[0, 2, 1, 3])
        scores = paddle.matmul(q, k, transpose_y=True)
        scores = paddle.scale(scores, scale=self.scale_value)
        return scores


class ElectraForSPO(ElectraPretrainedModel):
    """
    Electra Model for SPO (subject-predicate-object) extraction, with a
    token-level entity classifier and a per-predicate span attention head.
    """

    def __init__(self, electra, num_classes, dropout=None):
        super(ElectraForSPO, self).__init__()
        self.num_classes = num_classes
        self.electra = electra
        self.dropout = nn.Dropout(dropout if dropout is not None else
                                  self.electra.config['hidden_dropout_prob'])
        self.classifier = nn.Linear(self.electra.config['hidden_size'], 2)
        self.span_attention = MultiHeadAttentionForSPO(
            self.electra.config['hidden_size'], num_classes)
        self.sigmoid = paddle.nn.Sigmoid()

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        sequence_outputs, _, all_hidden_states = self.electra(
            input_ids,
            token_type_ids,
            position_ids,
            attention_mask,
            output_hidden_states=True)
        sequence_outputs = self.dropout(sequence_outputs)
        ent_logits = self.classifier(sequence_outputs)

        subject_output = all_hidden_states[-2]
        cls_output = paddle.unsqueeze(sequence_outputs[:, 0, :], axis=1)
        subject_output = subject_output + cls_output

        rel_logits = self.span_attention(sequence_outputs, subject_output)

        ent_logits = self.sigmoid(ent_logits)
        rel_logits = self.sigmoid(rel_logits)

        return ent_logits, rel_logits
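
For reference, a minimal usage sketch of the two-headed token classifier above, assuming the `ehealth-chinese` weights referenced by the training scripts in this commit; the label-set sizes and dummy inputs are placeholders, not part of the commit:

import paddle

from model import ElectraForBinaryTokenClassification

# Hypothetical label-set sizes for the 'others' and 'symptom' heads.
num_classes = [10, 5]
model = ElectraForBinaryTokenClassification.from_pretrained(
    'ehealth-chinese', num_classes=num_classes)

# Dummy batch of 2 sequences with 16 tokens each.
input_ids = paddle.randint(low=1, high=1000, shape=[2, 16])
token_type_ids = paddle.zeros(shape=[2, 16], dtype='int64')

logits_oth, logits_sym = model(input_ids, token_type_ids)
print(logits_oth.shape, logits_sym.shape)  # [2, 16, 10] and [2, 16, 5]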

examples/biomedical/cblue/train_classification.py

Lines changed: 4 additions & 5 deletions
@@ -26,7 +26,7 @@
 import paddlenlp as ppnlp
 from paddlenlp.data import Stack, Tuple, Pad
 from paddlenlp.datasets import load_dataset
-from paddlenlp.transformers import ElectraForSequenceClassification, LinearDecayWithWarmup
+from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer, LinearDecayWithWarmup
 from paddlenlp.metrics import MultiLabelsMetric, AccuracyAndF1
 from paddlenlp.ops.optimizer import ExponentialMovingAverage

@@ -117,16 +117,15 @@ def do_train():
 
     set_seed(args.seed)
 
-    train_ds, dev_ds, test_ds = load_dataset(
-        'cblue', args.dataset, splits=['train', 'dev', 'test'])
+    train_ds, dev_ds = load_dataset(
+        'cblue', args.dataset, splits=['train', 'dev'])
 
     model = ElectraForSequenceClassification.from_pretrained(
         'ehealth-chinese',
         num_classes=len(train_ds.label_list),
         activation='tanh',
         layer_norm_eps=1e-5)
-    tokenizer = ppnlp.transformers.ElectraTokenizer.from_pretrained(
-        'ehealth-chinese')
+    tokenizer = ElectraTokenizer.from_pretrained('ehealth-chinese')
 
     trans_func = partial(
         convert_example,
Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
from functools import partial
import argparse
import os
import random
import time
import distutils.util

import numpy as np
import paddle
from paddlenlp.data import Pad, Dict
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import LinearDecayWithWarmup, ElectraTokenizer
from paddlenlp.metrics import ChunkEvaluator

from model import ElectraForBinaryTokenClassification
from utils import create_dataloader, convert_example_ner

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu', 'npu'], default='gpu', help='Select which device to train the model, defaults to gpu.')
parser.add_argument('--init_from_ckpt', default=None, type=str, help='The path of checkpoint to be loaded.')
parser.add_argument('--batch_size', default=8, type=int, help='Batch size per GPU/CPU for training.')
parser.add_argument('--learning_rate', default=6e-5, type=float, help='Learning rate for fine-tuning the token classification task.')
parser.add_argument('--max_seq_length', default=128, type=int, help='The maximum total input sequence length after tokenization.')
parser.add_argument('--valid_steps', default=100, type=int, help='The interval steps to evaluate model performance.')
parser.add_argument('--logging_steps', default=10, type=int, help='The interval steps for logging.')
parser.add_argument('--save_steps', default=10000, type=int, help='The interval steps to save checkpoints.')
parser.add_argument('--weight_decay', default=0.01, type=float, help='Weight decay if we apply some.')
parser.add_argument('--warmup_proportion', default=0.1, type=float, help='Linear warmup proportion over the training process.')
parser.add_argument('--use_amp', default=False, type=distutils.util.strtobool, help='Enable mixed precision training.')
parser.add_argument('--scale_loss', default=2**15, type=float, help='The initial loss scaling for mixed precision training.')
parser.add_argument('--epochs', default=1, type=int, help='Total number of training epochs.')
parser.add_argument('--eval_mention', default=True, type=distutils.util.strtobool, help='.')
parser.add_argument('--update_tokenizer', default=True, type=distutils.util.strtobool, help='Update the word tokenizer during training.')
parser.add_argument('--seed', default=1000, type=int, help='Random seed.')
parser.add_argument('--save_dir', default='./checkpoint', type=str, help='The output directory where the model checkpoints will be written.')

args = parser.parse_args()
# yapf: enable


def set_seed(seed):
    """Set random seed."""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)


@paddle.no_grad()
def evaluate(model, criterion, metrics, data_loader):
    model.eval()
    metrics[0].reset()
    losses = []
    for batch in data_loader:
        input_ids, token_type_ids, position_ids, masks, label_oth, label_sym = batch
        logits = model(input_ids, token_type_ids, position_ids)
        loss_oth = criterion(logits[0], paddle.unsqueeze(label_oth, 2))
        loss_oth = paddle.mean(loss_oth * paddle.unsqueeze(masks, 2))
        loss_sym = criterion(logits[1], paddle.unsqueeze(label_sym, 2))
        loss_sym = paddle.mean(loss_sym * paddle.unsqueeze(masks, 2))

        losses.append([loss_oth.numpy(), loss_sym.numpy()])

        lengths = paddle.sum(masks, axis=1)
        pred_oth = paddle.argmax(logits[0], axis=2)
        pred_sym = paddle.argmax(logits[1], axis=2)
        correct_oth = metrics[0].compute(lengths, pred_oth, label_oth)
        correct_sym = metrics[1].compute(lengths, pred_sym, label_sym)
        correct_oth = [x.numpy() for x in correct_oth]
        correct_sym = [x.numpy() for x in correct_sym]
        metrics[0].update(*correct_oth)
        metrics[0].update(*correct_sym)
    _, _, result = metrics[0].accumulate()
    loss = np.mean(losses, axis=0)
    print('eval loss symptom: %.5f, loss others: %.5f, f1: %.5f' %
          (loss[1], loss[0], result))
    model.train()
    metrics[0].reset()


def do_train():
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    train_ds, dev_ds = load_dataset('cblue', 'CMeEE', splits=['train', 'dev'])

    model = ElectraForBinaryTokenClassification.from_pretrained(
        'ehealth-chinese', num_classes=[len(x) for x in train_ds.label_list])
    tokenizer = ElectraTokenizer.from_pretrained('ehealth-chinese')

    label_list = train_ds.label_list
    pad_label_id = [len(label_list[0]) - 1, len(label_list[1]) - 1]
    ignore_label_id = -100

    trans_func = partial(
        convert_example_ner,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        pad_label_id=pad_label_id)

    # Collate function: pad each field of a batch to the same length.
    batchify_fn = lambda samples, fn=Dict({
        'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),
        'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'),
        'position_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),
        'mask': Pad(axis=0, pad_val=0, dtype='float32'),
        'label_oth': Pad(axis=0, pad_val=pad_label_id[0], dtype='int64'),
        'label_sym': Pad(axis=0, pad_val=pad_label_id[1], dtype='int64')
    }): fn(samples)

    train_data_loader = create_dataloader(
        train_ds,
        mode='train',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    dev_data_loader = create_dataloader(
        dev_ds,
        mode='dev',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    if args.init_from_ckpt:
        if not os.path.isfile(args.init_from_ckpt):
            raise ValueError('init_from_ckpt is not a valid model filename.')
        state_dict = paddle.load(args.init_from_ckpt)
        model.set_dict(state_dict)
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    num_training_steps = len(train_data_loader) * args.epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
                                         args.warmup_proportion)

    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ['bias', 'norm'])
    ]

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    criterion = paddle.nn.functional.softmax_with_cross_entropy

    metrics = [ChunkEvaluator(label_list[0]), ChunkEvaluator(label_list[1])]

    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)

    global_step = 0
    tic_train = time.time()
    total_train_time = 0
    for epoch in range(1, args.epochs + 1):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, position_ids, masks, label_oth, label_sym = batch
            with paddle.amp.auto_cast(
                    args.use_amp,
                    custom_white_list=['layer_norm', 'softmax', 'gelu']):
                logits = model(input_ids, token_type_ids, position_ids, masks)

                loss_oth = criterion(logits[0], paddle.unsqueeze(label_oth, 2))
                loss_sym = criterion(logits[1], paddle.unsqueeze(label_sym, 2))
                loss_masks = paddle.unsqueeze(masks, 2)
                loss_oth = paddle.mean(loss_oth * loss_masks)
                loss_sym = paddle.mean(loss_sym * loss_masks)

                loss = loss_oth + loss_sym

            lengths = paddle.sum(masks, axis=1)
            pred_oth = paddle.argmax(logits[0], axis=-1)
            pred_sym = paddle.argmax(logits[1], axis=-1)
            correct_oth = metrics[0].compute(lengths, pred_oth, label_oth)
            correct_sym = metrics[1].compute(lengths, pred_sym, label_sym)
            correct_oth = [x.numpy() for x in correct_oth]
            correct_sym = [x.numpy() for x in correct_sym]
            metrics[0].update(*correct_oth)
            metrics[0].update(*correct_sym)
            _, _, f1 = metrics[0].accumulate()

            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            global_step += 1
            if global_step % args.logging_steps == 0 and rank == 0:
                time_diff = time.time() - tic_train
                total_train_time += time_diff
                print(
                    'global step %d, epoch: %d, batch: %d, loss: %.5f, loss symptom: %.5f, loss others: %.5f, f1: %.5f, speed: %.2f step/s'
                    % (global_step, epoch, step, loss, loss_sym, loss_oth,
                       f1, args.logging_steps / time_diff))
                tic_train = time.time()

            if global_step % args.valid_steps == 0 and rank == 0:
                evaluate(model, criterion, metrics, dev_data_loader)
                tic_train = time.time()

            if global_step % args.save_steps == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir,
                                        'model_%d' % global_step)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                if paddle.distributed.get_world_size() > 1:
                    model._layers.save_pretrained(save_dir)
                else:
                    model.save_pretrained(save_dir)
                tokenizer.save_pretrained(save_dir)
                tic_train = time.time()
    print('Speed: %.2f steps/s' % (global_step / total_train_time))


if __name__ == '__main__':
    do_train()
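
Both evaluate() and the training loop above compute the token-level loss with softmax_with_cross_entropy and then zero out padded positions with the float mask before averaging. A self-contained sketch of that pattern on dummy tensors (shapes and values are illustrative only, not taken from the CMeEE data):

import paddle
import paddle.nn.functional as F

batch_size, seq_len, num_labels = 2, 6, 4
logits = paddle.randn([batch_size, seq_len, num_labels])
labels = paddle.randint(low=0, high=num_labels, shape=[batch_size, seq_len])
# 1.0 marks real tokens, 0.0 marks padding (same convention as the 'mask' field above).
masks = paddle.to_tensor([[1., 1., 1., 1., 0., 0.],
                          [1., 1., 1., 0., 0., 0.]])

# Per-token cross entropy, shape [batch_size, seq_len, 1].
token_loss = F.softmax_with_cross_entropy(logits, paddle.unsqueeze(labels, 2))
# Zero out the padded positions, then average (as in the training loop above).
loss = paddle.mean(token_loss * paddle.unsqueeze(masks, 2))
print(float(loss))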
