Skip to content

Commit 0287e54

Browse files
27182812yingyibiao
andauthored
add model chinesebert (#1100)
* add model chinesebert add model chinesebert paddle model weights: https://drive.google.com/file/d/1Cze9mZGX9_1btGCCYTdb7PA2PNPvw2is/view?usp=sharing * 修改完善 * update * delete logs * delete datasets * delete * update Co-authored-by: yingyibiao <[email protected]>
1 parent e27ac6a commit 0287e54

File tree

16 files changed

+3784
-0
lines changed

16 files changed

+3784
-0
lines changed
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# ChineseBert with PaddleNLP
2+
3+
[ChineseBERT: Chinese Pretraining Enhanced by Glyph and Pinyin Information](https://arxiv.org/pdf/2106.16038.pdf)
4+
5+
**摘要:**
6+
最近的汉语预训练模型忽略了汉语特有的两个重要方面:字形和拼音,它们对语言理解具有重要的语法和语义信息。在本研究中,我们提出了汉语预训练,它将汉字的字形和拼音信息纳入语言模型预训练中。字形嵌入是基于汉字的不同字体获得的,能够从视觉特征中捕捉汉字语义,拼音嵌入代表汉字的发音,处理汉语中高度流行的异义现象(同一汉字具有不同的发音和不同的含义)。在大规模的未标记中文语料库上进行预训练后,所提出的ChineseBERT模型在训练步骤较少的基线模型上产生了显著的性能提高。该模型在广泛的中国自然语言处理任务上实现了新的SOTA性能,包括机器阅读理解、自然语言推理、文本分类、句子对匹配和命名实体识别方面的竞争性能。
7+
8+
本项目是 ChineseBert 在 Paddle 2.x上的开源实现。
9+
10+
## **数据准备**
11+
涉及到的ChnSentiCorp,crmc2018,XNLI数据
12+
部分Paddle已提供,其他可参考https://github.com/27182812/ChineseBERT_paddle,
13+
在data目录下。
14+
15+
16+
## **模型预训练**
17+
模型预训练过程可参考[Electra的README](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/examples/language_model/electra/README.md)
18+
19+
## **Fine-tuning**
20+
21+
### 运行Fine-tuning
22+
23+
#### **使用Paddle提供的预训练模型运行 Fine-tuning**
24+
25+
#### 1、ChnSentiCorp
26+
以ChnSentiCorp数据集为例
27+
28+
#### (1)模型微调:
29+
```shell
30+
# 运行训练
31+
python train_chn.py \
32+
--data_path './data/ChnSentiCorp' \
33+
--device 'gpu' \
34+
--epochs 10 \
35+
--max_seq_length 512 \
36+
--batch_size 8 \
37+
--learning_rate 2e-5 \
38+
--weight_decay 0.0001 \
39+
--warmup_proportion 0.1 \
40+
--seed 2333 \
41+
--save_dir 'outputs/chn' | tee outputs/train_chn.log
42+
```
43+
其中参数释义如下:
44+
- `data_path` 表示微调数据路径
45+
- `device` 表示使用的设备类型。默认为GPU,可以配置为CPU、GPU、XPU。若希望使用多GPU训练,将其设置为GPU,同时环境变量CUDA_VISIBLE_DEVICES配置要使用的GPU id。
46+
- `epochs` 表示训练轮数。
47+
- `max_seq_length` 表示最大句子长度,超过该长度将被截断。
48+
- `batch_size` 表示每次迭代**每张卡**上的样本数目。
49+
- `learning_rate` 表示基础学习率大小,将于learning rate scheduler产生的值相乘作为当前学习率。
50+
- `weight_decay` 表示优化器中使用的weight_decay的系数。
51+
- `warmup_steps` 表示动态学习率热启动的step数。
52+
- `seed` 指定随机种子。
53+
- `save_dir` 表示模型保存路径。
54+
55+
#### (2) 评估
56+
57+
在dev和test数据集上acc分别为95.8和96.08,达到论文精度要求。
58+
59+
#### 2、XNLI
60+
61+
#### (1)训练
62+
63+
```bash
64+
python train_xnli.py \
65+
--data_path './data/XNLI' \
66+
--device 'gpu' \
67+
--epochs 5 \
68+
--max_seq_len 256 \
69+
--batch_size 16 \
70+
--learning_rate 1.3e-5 \
71+
--weight_decay 0.001 \
72+
--warmup_proportion 0.1 \
73+
--seed 2333 \
74+
--save_dir outputs/xnli | tee outputs/train_xnli.log
75+
```
76+
其中参数释义如下:
77+
- `data_path` 表示微调数据路径
78+
- `device` 表示使用的设备类型。默认为GPU,可以配置为CPU、GPU、XPU。若希望使用多GPU训练,将其设置为GPU,同时环境变量CUDA_VISIBLE_DEVICES配置要使用的GPU id。
79+
- `epochs` 表示训练轮数。
80+
- `max_seq_length` 表示最大句子长度,超过该长度将被截断。
81+
- `batch_size` 表示每次迭代**每张卡**上的样本数目。
82+
- `learning_rate` 表示基础学习率大小,将于learning rate scheduler产生的值相乘作为当前学习率。
83+
- `weight_decay` 表示优化器中使用的weight_decay的系数。
84+
- `warmup_steps` 表示动态学习率热启动的step数。
85+
- `seed` 指定随机种子。
86+
- `save_dir` 表示模型保存路径。
87+
88+
#### (2)评估
89+
90+
test数据集 acc最好结果为81.657,达到论文精度要求。
91+
92+
#### 3、cmrc2018
93+
94+
#### (1) 训练
95+
96+
```shell
97+
# 开始训练
98+
python train_cmrc2018.py \
99+
--data_dir "data/cmrc2018" \
100+
--model_name_or_path ChineseBERT-large \
101+
--max_seq_length 512 \
102+
--train_batch_size 8 \
103+
--gradient_accumulation_steps 8 \
104+
--eval_batch_size 16 \
105+
--learning_rate 4e-5 \
106+
--max_grad_norm 1.0 \
107+
--num_train_epochs 3 \
108+
--logging_steps 2 \
109+
--save_steps 20 \
110+
--warmup_radio 0.1 \
111+
--weight_decay 0.01 \
112+
--output_dir outputs/cmrc2018 \
113+
--seed 1111 \
114+
--num_workers 0 \
115+
--use_amp
116+
```
117+
其中参数释义如下:
118+
- `data_path` 表示微调数据路径。
119+
- `model_name_or_path` 模型名称或者路径,支持ChineseBERT-base、ChineseBERT-large两种种规格。
120+
- `max_seq_length` 表示最大句子长度,超过该长度将被截断。
121+
- `train_batch_size` 表示训练过程中每次迭代**每张卡**上的样本数目。
122+
- `gradient_accumulation_steps` 梯度累加步数。
123+
- `eval_batch_size` 表示验证过程中每次迭代**每张卡**上的样本数目。
124+
- `learning_rate` 表示基础学习率大小,将于learning rate scheduler产生的值相乘作为当前学习率。
125+
- `max_grad_norm` 梯度裁剪。
126+
- `num_train_epochs` 表示训练轮数。
127+
- `logging_steps` 表示日志打印间隔。
128+
- `warmup_radio` 表示动态学习率热启动的比例。
129+
- `weight_decay` 表示优化器中使用的weight_decay的系数。
130+
- `output_dir` 表示模型保存路径。
131+
- `seed` 指定随机种子。
132+
- `num_workers` 表示同时工作进程。
133+
- `use_amp` 表示是否使用混合精度。
134+
135+
训练过程中模型会在dev数据集进行评估,其中最好的结果如下所示:
136+
137+
```python
138+
139+
{
140+
AVERAGE = 82.791
141+
F1 = 91.055
142+
EM = 74.526
143+
TOTAL = 3219
144+
SKIP = 0
145+
}
146+
147+
```
148+
149+
#### (2)运行eval_cmrc.py,生成test数据集预测答案
150+
151+
```bash
152+
python eval_cmrc.py --model_name_or_path outputs/step-340 --n_best_size 35 --max_answer_length 65
153+
```
154+
155+
其中,model_name_or_path为模型路径
156+
157+
#### (3)提交CLUE
158+
159+
test数据集 EM为78.55,达到论文精度要求
160+
161+
162+
## Reference
163+
164+
```bibtex
165+
@article{sun2021chinesebert,
166+
title={ChineseBERT: Chinese Pretraining Enhanced by Glyph and Pinyin Information},
167+
author={Sun, Zijun and Li, Xiaoya and Sun, Xiaofei and Meng, Yuxian and Ao, Xiang and He, Qing and Wu, Fei and Li, Jiwei},
168+
journal={arXiv preprint arXiv:2106.16038},
169+
year={2021}
170+
}
171+
172+
```
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
python eval.py --model_name_or_path outputs/cmrc2018/step-140 --n_best_size 35 --max_answer_length 65
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
#encoding=utf8
2+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
3+
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
'''
16+
Evaluation script for CMRC 2018
17+
version: v5 - special
18+
Note:
19+
v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets
20+
v5: formatted output, add usage description
21+
v4: fixed segmentation issues
22+
'''
23+
24+
import argparse
25+
import json
26+
import re
27+
import sys
28+
from collections import OrderedDict
29+
import nltk
30+
31+
32+
# split Chinese with English
33+
def mixed_segmentation(in_str, rm_punc=False):
34+
in_str = str(in_str).lower().strip()
35+
segs_out = []
36+
temp_str = ""
37+
sp_char = [
38+
'-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
39+
'?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
40+
')', '-', '~', '『', '』'
41+
]
42+
for char in in_str:
43+
if rm_punc and char in sp_char:
44+
continue
45+
if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
46+
if temp_str != "":
47+
ss = nltk.word_tokenize(temp_str)
48+
segs_out.extend(ss)
49+
temp_str = ""
50+
segs_out.append(char)
51+
else:
52+
temp_str += char
53+
54+
# handling last part
55+
if temp_str != "":
56+
ss = nltk.word_tokenize(temp_str)
57+
segs_out.extend(ss)
58+
59+
return segs_out
60+
61+
62+
# remove punctuation
63+
def remove_punctuation(in_str):
64+
in_str = str(in_str).lower().strip()
65+
sp_char = [
66+
'-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
67+
'?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
68+
')', '-', '~', '『', '』'
69+
]
70+
out_segs = []
71+
for char in in_str:
72+
if char in sp_char:
73+
continue
74+
else:
75+
out_segs.append(char)
76+
return ''.join(out_segs)
77+
78+
79+
# find longest common string
80+
def find_lcs(s1, s2):
81+
m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)]
82+
mmax = 0
83+
p = 0
84+
for i in range(len(s1)):
85+
for j in range(len(s2)):
86+
if s1[i] == s2[j]:
87+
m[i + 1][j + 1] = m[i][j] + 1
88+
if m[i + 1][j + 1] > mmax:
89+
mmax = m[i + 1][j + 1]
90+
p = i + 1
91+
return s1[p - mmax:p], mmax
92+
93+
94+
#
95+
def evaluate(ground_truth_file, prediction_file):
96+
f1 = 0
97+
em = 0
98+
total_count = 0
99+
skip_count = 0
100+
for instance in ground_truth_file["data"]:
101+
# context_id = instance['context_id'].strip()
102+
# context_text = instance['context_text'].strip()
103+
for para in instance["paragraphs"]:
104+
for qas in para['qas']:
105+
total_count += 1
106+
query_id = qas['id'].strip()
107+
query_text = qas['question'].strip()
108+
answers = [x["text"] for x in qas['answers']]
109+
110+
if query_id not in prediction_file:
111+
sys.stderr.write('Unanswered question: {}\n'.format(
112+
query_id))
113+
skip_count += 1
114+
continue
115+
116+
prediction = str(prediction_file[query_id])
117+
f1 += calc_f1_score(answers, prediction)
118+
em += calc_em_score(answers, prediction)
119+
120+
f1_score = 100.0 * f1 / total_count
121+
em_score = 100.0 * em / total_count
122+
return f1_score, em_score, total_count, skip_count
123+
124+
125+
def calc_f1_score(answers, prediction):
126+
f1_scores = []
127+
for ans in answers:
128+
ans_segs = mixed_segmentation(ans, rm_punc=True)
129+
prediction_segs = mixed_segmentation(prediction, rm_punc=True)
130+
lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
131+
if lcs_len == 0:
132+
f1_scores.append(0)
133+
continue
134+
precision = 1.0 * lcs_len / len(prediction_segs)
135+
recall = 1.0 * lcs_len / len(ans_segs)
136+
f1 = (2 * precision * recall) / (precision + recall)
137+
f1_scores.append(f1)
138+
return max(f1_scores)
139+
140+
141+
def calc_em_score(answers, prediction):
142+
em = 0
143+
for ans in answers:
144+
ans_ = remove_punctuation(ans)
145+
prediction_ = remove_punctuation(prediction)
146+
if ans_ == prediction_:
147+
em = 1
148+
break
149+
return em
150+
151+
152+
def get_result(ground_truth_file, prediction_file):
153+
ground_truth_file = json.load(open(ground_truth_file, 'rb'))
154+
prediction_file = json.load(open(prediction_file, 'rb'))
155+
F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
156+
AVG = (EM + F1) * 0.5
157+
output_result = OrderedDict()
158+
output_result['AVERAGE'] = '%.3f' % AVG
159+
output_result['F1'] = '%.3f' % F1
160+
output_result['EM'] = '%.3f' % EM
161+
output_result['TOTAL'] = TOTAL
162+
output_result['SKIP'] = SKIP
163+
print(json.dumps(output_result))
164+
return output_result
165+
166+
167+
if __name__ == '__main__':
168+
parser = argparse.ArgumentParser(
169+
description='Evaluation Script for CMRC 2018')
170+
parser.add_argument(
171+
'--dataset_file',
172+
default="cmrc2018_public/dev.json",
173+
help='Official dataset file')
174+
parser.add_argument(
175+
'--prediction_file',
176+
default="all_predictions.json",
177+
help='Your prediction File')
178+
args = parser.parse_args()
179+
ground_truth_file = json.load(open(args.dataset_file, 'rb'))
180+
prediction_file = json.load(open(args.prediction_file, 'rb'))
181+
F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
182+
AVG = (EM + F1) * 0.5
183+
output_result = OrderedDict()
184+
output_result['AVERAGE'] = '%.3f' % AVG
185+
output_result['F1'] = '%.3f' % F1
186+
output_result['EM'] = '%.3f' % EM
187+
output_result['TOTAL'] = TOTAL
188+
output_result['SKIP'] = SKIP
189+
output_result['FILE'] = args.prediction_file
190+
print(json.dumps(output_result))

0 commit comments

Comments
 (0)