
Commit 5c58459

tianxin and Zeyu Chen authored
Add Question matching baseline (#874)
* finish QM baseline
* update
* update
* update
* add README.md
* Update README.md
* implement rdrop as paddlenlp api
* add rdrop.py into paddlenlp.losses module

Co-authored-by: Zeyu Chen <[email protected]>
1 parent 7e098f1 commit 5c58459

File tree

9 files changed: +673 -1 lines changed

Lines changed: 126 additions & 0 deletions (README.md)
@@ -0,0 +1,126 @@
# Qianyan (千言) Question Matching Robustness Evaluation Baseline

We build a baseline solution and evaluation results for the [Qianyan Question Matching Robustness Evaluation competition]() on top of the pretrained model ERNIE-Gram.

## Evaluation Results

This project trains single-tower point-wise matching models on three Chinese pretrained models: ERNIE-1.0, Bert-base-chinese, and ERNIE-Gram. The ERNIE-Gram based model clearly outperforms the other two.

On top of the ERNIE-Gram model we also evaluate the recent regularization strategy [R-Drop](https://arxiv.org/abs/2106.14448). Its core idea is to run the same training sample through the network several times and add a regularization loss that constrains the resulting outputs to agree; a minimal sketch of this loss is given after the results table below.

| Model | rdrop_coef | dev acc | test-A acc | test-B acc |
| ---- | ---- | ----- | -------- | ------- |
| ernie-1.0-base | 0.0 | 86.96 | 76.20 | 77.50 |
| bert-base-chinese | 0.0 | 86.93 | 76.90 | 77.60 |
| ernie-gram-zh | 0.0 | 87.66 | **80.80** | **81.20** |
| ernie-gram-zh | 0.1 | 87.91 | 80.20 | 80.80 |
| ernie-gram-zh | 0.2 | 87.47 | 80.10 | 81.00 |

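In this commit the regularizer is exposed as `paddlenlp.losses.RDropLoss` (added as `rdrop.py` under the `paddlenlp.losses` module) and is invoked from `model.py` below. As a rough sketch of the idea only, not the library implementation, the extra loss is the symmetric KL divergence between the class distributions produced by two dropout-perturbed forward passes over the same batch:

```python
import paddle
import paddle.nn.functional as F


def rdrop_kl_sketch(logits_1, logits_2):
    """Illustrative R-Drop regularizer: symmetric KL between the predictive
    distributions of two forward passes of the same batch (dropout makes the
    two passes differ). The baseline itself calls paddlenlp.losses.RDropLoss."""
    p = F.softmax(logits_1, axis=-1)
    q = F.softmax(logits_2, axis=-1)
    # KL(p || q) + KL(q || p), averaged over the batch.
    kl_pq = paddle.sum(p * (paddle.log(p + 1e-12) - paddle.log(q + 1e-12)), axis=-1)
    kl_qp = paddle.sum(q * (paddle.log(q + 1e-12) - paddle.log(p + 1e-12)), axis=-1)
    return paddle.mean(kl_pq + kl_qp) / 2
```
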
## Quick Start

### Code Structure

The main code of this project is organized as follows:

```
question_matching/
├── model.py   # network definition of the matching model
├── data.py    # data reading and conversion logic for training samples
├── predict.py # prediction script; outputs 0/1 predictions for the test set
└── train.py   # model training and evaluation
```

### Data Preparation

This project uses the union of the training sets of the three datasets provided by the competition (LCQMC, BQ, and OPPO) as its training set, and the union of their dev sets as its dev set.

Run the following commands to generate the training and dev sets used in this project. During the competition you are free to explore other combinations of training and dev data; you do not have to follow the baseline exactly.

```shell
cat ./data/train/LCQMC/train ./data/train/BQ/train ./data/train/OPPO/train > train.txt
cat ./data/train/LCQMC/dev ./data/train/BQ/dev ./data/train/OPPO/dev > dev.txt
```

The training set has 3 tab-separated columns: text_a \t text_b \t label. Sample rows:

```text
喜欢打篮球的男生喜欢什么样的女生 爱打篮球的男生喜欢什么样的女生 1
我手机丢了,我想换个手机 我想买个新手机,求推荐 1
大家觉得她好看吗 大家觉得跑男好看吗? 0
求秋色之空漫画全集 求秋色之空全集漫画 1
晚上睡觉带着耳机听音乐有什么害处吗? 孕妇可以戴耳机听音乐吗? 0
```

The dev set uses the same format as the training set. Sample rows:

```
开初婚未育证明怎么弄? 初婚未育情况证明怎么开? 1
谁知道她是网络美女吗? 爱情这杯酒谁喝都会醉是什么歌 0
男孩喝女孩的尿的故事 怎样才知道是生男孩还是女孩 0
这种图片是用什么软件制作的? 这种图片制作是用什么软件呢? 1
```

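After generating `train.txt` and `dev.txt` you may want to sanity-check them. The snippet below is an optional helper, not part of this commit; it assumes the three-column tab-separated format described above and reports the number of usable pairs and the label distribution:

```python
from collections import Counter


def inspect_split(path):
    """Count well-formed 3-column (text_a, text_b, label) rows and the label distribution."""
    labels = Counter()
    skipped = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            fields = line.rstrip("\n").split("\t")
            if len(fields) != 3:
                skipped += 1  # malformed rows are also silently dropped by data.py
                continue
            labels[fields[2]] += 1
    print(f"{path}: {sum(labels.values())} pairs, labels={dict(labels)}, skipped={skipped}")


inspect_split("train.txt")
inspect_split("dev.txt")
```
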
### Model Training

Run the following command to reproduce the ERNIE-Gram baseline model of this project:

```shell
unset CUDA_VISIBLE_DEVICES
python -u -m paddle.distributed.launch --gpus "0,1,2,3" train.py \
       --train_set train.txt \
       --dev_set dev.txt \
       --device gpu \
       --eval_step 100 \
       --save_dir ./checkpoints \
       --train_batch_size 32 \
       --learning_rate 2E-5 \
       --rdrop_coef 0.0
```

Configurable parameters:

* `train_set`: the training set file.
* `dev_set`: the dev set file.
* `rdrop_coef`: optional; the coefficient of the R-Drop KL regularization loss. Defaults to 0.0, i.e. R-Drop is disabled. A sketch of how it enters the training loss follows this list.
* `train_batch_size`: optional; the batch size. Adjust it to your GPU memory and lower it if you run out of memory. Defaults to 32.
* `learning_rate`: optional; the peak learning rate for fine-tuning. Defaults to 5e-5.
* `weight_decay`: optional; the weight-decay coefficient used to control regularization strength and prevent overfitting. Defaults to 0.0.
* `epochs`: the number of training epochs. Defaults to 3.
* `warmup_proption`: optional; the proportion of steps used for learning-rate warmup. With 0.1, the learning rate grows linearly from 0 to `learning_rate` over the first 10% of training steps and then slowly decays. Defaults to 0.0.
* `init_from_ckpt`: optional; a path to model parameters used to warm-start training. Defaults to None.
* `seed`: optional; the random seed. Defaults to 1000.
* `device`: the device to train on, either cpu or gpu. When training on gpu, the `gpus` argument selects the GPU card ids.

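`train.py` itself is not reproduced in this excerpt. As a hedged sketch of how `rdrop_coef`, `learning_rate`, `weight_decay` and the warmup proportion typically fit together (function and variable names below are illustrative, not the exact `train.py` code):

```python
import paddle
import paddle.nn.functional as F
from paddlenlp.transformers import LinearDecayWithWarmup


def build_optimizer(model, learning_rate, num_training_steps,
                    warmup_proportion, weight_decay):
    # Warm up over the first `warmup_proportion` of steps, then decay linearly.
    lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps,
                                         warmup_proportion)
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=weight_decay)
    return optimizer, lr_scheduler


def train_step(model, batch, optimizer, lr_scheduler, rdrop_coef):
    input_ids, token_type_ids, labels = batch
    # The QuestionMatching model (model.py) returns logits plus an R-Drop KL term.
    logits, kl_loss = model(input_ids=input_ids, token_type_ids=token_type_ids)
    ce_loss = F.cross_entropy(input=logits, label=labels)
    # rdrop_coef == 0.0 disables the regularizer (kl_loss is then just 0.0).
    loss = ce_loss + rdrop_coef * kl_loss if rdrop_coef > 0 else ce_loss
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.clear_grad()
    return float(loss)
```
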
The program trains and evaluates automatically, and saves checkpoints to the specified `save_dir` during training.

After each evaluation on the dev set, the current model is saved only if its dev metric beats the best dev metric seen so far; otherwise it is skipped. As a result, when training finishes, the checkpoint with the largest step number under the save directory is the one with the best dev metric, and that is normally the model used for prediction. A minimal sketch of this save-if-better logic follows the directory listing below.

For example:
```text
checkpoints/
├── model_10000
│   ├── model_state.pdparams
│   ├── tokenizer_config.json
│   └── vocab.txt
└── ...
```

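The save-if-better behaviour described above can be summarised by the following sketch (illustrative only; not the exact `train.py` code):

```python
import os

import paddle


def maybe_save_best(model, tokenizer, dev_accu, best_accu, save_dir, global_step):
    """Save a checkpoint only when the dev metric improves, so the checkpoint
    with the largest step number under `save_dir` is the best model."""
    if dev_accu <= best_accu:
        return best_accu
    save_path = os.path.join(save_dir, "model_%d" % global_step)
    os.makedirs(save_path, exist_ok=True)
    paddle.save(model.state_dict(),
                os.path.join(save_path, "model_state.pdparams"))
    tokenizer.save_pretrained(save_path)  # writes tokenizer_config.json and vocab.txt
    return dev_accu
```
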
**NOTE:**
* To resume model training, set `init_from_ckpt`, e.g. `init_from_ckpt=checkpoints/model_100/model_state.pdparams`.

### Prediction

After training, the checkpoint with the best dev metric is stored under the specified checkpoints path. Run the following command to generate predictions:
```shell
unset CUDA_VISIBLE_DEVICES
python -u \
    predict.py \
    --device gpu \
    --params_path "./checkpoints/model_10000/model_state.pdparams" \
    --batch_size 128 \
    --input_file "${test_set}" \
    --result_file "predict_result"
```

Example of the prediction output:
```text
0
1
0
1
```

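Before submitting, it can be worth verifying that the result file contains exactly one 0/1 label per test example. The helper below is optional and not part of this commit; the file names passed to it are placeholders:

```python
def check_result(test_file, result_file):
    """Sanity-check the prediction file against the test set before submission."""
    with open(test_file, encoding="utf-8") as f:
        num_test = sum(1 for _ in f)
    with open(result_file, encoding="utf-8") as f:
        preds = [line.strip() for line in f]
    assert len(preds) == num_test, f"{len(preds)} predictions for {num_test} test rows"
    assert set(preds) <= {"0", "1"}, "every prediction must be 0 or 1"
    print(f"OK: {num_test} predictions")


# Replace "test.tsv" with the path of the competition test set you predicted on.
check_result("test.tsv", "predict_result")
```
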
### Submission

Submit the prediction result file for evaluation on the competition platform.

## Reference
[1] Liang, Xiaobo, Lijun Wu, Juntao Li, Yue Wang, Qi Meng, Tao Qin, Wei Chen, Min Zhang, and Tie-Yan Liu. “R-Drop: Regularized Dropout for Neural Networks.” ArXiv:2106.14448 [Cs], June 28, 2021. http://arxiv.org/abs/2106.14448.
Lines changed: 74 additions & 0 deletions (data.py)
@@ -0,0 +1,74 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np

from paddlenlp.datasets import MapDataset


def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    """Wraps a dataset into a DataLoader, using a distributed sampler for training."""
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)


def read_text_pair(data_path, is_test=False):
    """Reads tab-separated text pairs; malformed lines are skipped."""
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = line.rstrip().split("\t")
            if not is_test:
                if len(data) != 3:
                    continue
                yield {'query1': data[0], 'query2': data[1], 'label': data[2]}
            else:
                if len(data) != 2:
                    continue
                yield {'query1': data[0], 'query2': data[1]}


def convert_example(example, tokenizer, max_seq_length=512, is_test=False):
    """Tokenizes a query pair into a single point-wise input: input_ids, token_type_ids (and label)."""
    query, title = example["query1"], example["query2"]

    encoded_inputs = tokenizer(
        text=query, text_pair=title, max_seq_len=max_seq_length)

    input_ids = encoded_inputs["input_ids"]
    token_type_ids = encoded_inputs["token_type_ids"]

    if not is_test:
        label = np.array([example["label"]], dtype="int64")
        return input_ids, token_type_ids, label
    else:
        return input_ids, token_type_ids
Lines changed: 56 additions & 0 deletions (model.py)
@@ -0,0 +1,56 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

import paddlenlp as ppnlp


class QuestionMatching(nn.Layer):
    def __init__(self, pretrained_model, dropout=None, rdrop_coef=0.0):
        super().__init__()
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)

        # num_labels = 2 (similar or dissimilar)
        self.classifier = nn.Linear(self.ptm.config["hidden_size"], 2)
        self.rdrop_coef = rdrop_coef
        self.rdrop_loss = ppnlp.losses.RDropLoss()

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None,
                do_evaluate=False):

        _, cls_embedding1 = self.ptm(input_ids, token_type_ids, position_ids,
                                     attention_mask)
        cls_embedding1 = self.dropout(cls_embedding1)
        logits1 = self.classifier(cls_embedding1)

        # For more information about R-Drop please refer to this paper: https://arxiv.org/abs/2106.14448
        # For the original implementation please refer to this code: https://github.com/dropreg/R-Drop
        if self.rdrop_coef > 0 and not do_evaluate:
            # Second forward pass with a different dropout mask; the KL between
            # the two sets of logits is the R-Drop regularization term.
            _, cls_embedding2 = self.ptm(input_ids, token_type_ids, position_ids,
                                         attention_mask)
            cls_embedding2 = self.dropout(cls_embedding2)
            logits2 = self.classifier(cls_embedding2)
            kl_loss = self.rdrop_loss(logits1, logits2)
        else:
            kl_loss = 0.0

        return logits1, kl_loss
Lines changed: 120 additions & 0 deletions (predict.py)
@@ -0,0 +1,120 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import argparse
import sys
import os
import random
import time

import numpy as np
import paddle
import paddle.nn.functional as F
import paddlenlp as ppnlp
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad

from data import create_dataloader, read_text_pair, convert_example
from model import QuestionMatching

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, required=True, help="The full path of input file")
parser.add_argument("--result_file", type=str, required=True, help="The result file name")
parser.add_argument("--params_path", type=str, required=True, help="The path to model parameters to be loaded.")
parser.add_argument("--max_seq_length", default=256, type=int, help="The maximum total input sequence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
args = parser.parse_args()
# yapf: enable


def predict(model, data_loader):
    """
    Predicts the data labels.

    Args:
        model (obj:`QuestionMatching`): A model to calculate whether the question pair is semantically similar or not.
        data_loader (obj:`DataLoader`): The processed data ids of text pairs: [input_ids, token_type_ids]
    Returns:
        batch_logits (obj:`np.ndarray`): The predicted logits for each text pair.
    """
    batch_logits = []

    model.eval()

    with paddle.no_grad():
        for batch_data in data_loader:
            input_ids, token_type_ids = batch_data

            input_ids = paddle.to_tensor(input_ids)
            token_type_ids = paddle.to_tensor(token_type_ids)

            batch_logit, _ = model(
                input_ids=input_ids, token_type_ids=token_type_ids)

            batch_logits.append(batch_logit.numpy())

    batch_logits = np.concatenate(batch_logits, axis=0)

    return batch_logits


if __name__ == "__main__":
    paddle.set_device(args.device)

    pretrained_model = ppnlp.transformers.ErnieGramModel.from_pretrained(
        'ernie-gram-zh')
    tokenizer = ppnlp.transformers.ErnieGramTokenizer.from_pretrained(
        'ernie-gram-zh')

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        is_test=True)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment_ids
    ): [data for data in fn(samples)]

    test_ds = load_dataset(
        read_text_pair, data_path=args.input_file, is_test=True, lazy=False)

    test_data_loader = create_dataloader(
        test_ds,
        mode='predict',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    model = QuestionMatching(pretrained_model)

    if args.params_path and os.path.isfile(args.params_path):
        state_dict = paddle.load(args.params_path)
        model.set_dict(state_dict)
        print("Loaded parameters from %s" % args.params_path)
    else:
        raise ValueError(
            "Please set --params_path with correct pretrained model file")

    y_probs = predict(model, test_data_loader)
    y_preds = np.argmax(y_probs, axis=1)

    with open(args.result_file, 'w', encoding="utf-8") as f:
        for y_pred in y_preds:
            f.write(str(y_pred) + "\n")

0 commit comments
