
Commit a96c61f

Add SimBERT for text matching (#706)
* add SimBERT
* add with_pool for BertModel
* add SimBERT
* add SimBERT
* add SimBERT
* add SimBERT
* add SimBERT
* add SimBERT
* fix some problem in README.md
* fix the description of simbert
* fix the description of simbert
* fix the description of simbert
* fix the description of simbert
* fix the description of simbert
1 parent c59e496 commit a96c61f

File tree: 6 files changed, +261 −6 lines changed


docs/model_zoo/transformers.rst

Lines changed: 8 additions & 2 deletions
@@ -9,8 +9,8 @@ PaddleNLP provides users with commonly used pretrained models such as ``BERT``, ``ERNIE``, ``ALBERT``, ``RoBER
 Summary of Transformer pretrained models
 ------------------------------------
 
-The table below summarizes the pretrained models currently supported by PaddleNLP and their corresponding pretrained weights. We currently provide **67** pretrained parameter weights for users,
-**32** of which are pretrained weights for Chinese language models.
+The table below summarizes the pretrained models currently supported by PaddleNLP and their corresponding pretrained weights. We currently provide **68** pretrained parameter weights for users,
+**33** of which are pretrained weights for Chinese language models.
 
 +--------------------+-------------------------------------+--------------+-----------------------------------------+
 | Model              | Pretrained Weight                   | Language     | Details of the model                    |
@@ -115,6 +115,11 @@ Summary of Transformer pretrained models
 |                    |                                     |              | Trained on cased Chinese Simplified     |
 |                    |                                     |              | and Traditional text using              |
 |                    |                                     |              | Whole-Word-Masking with extended data.  |
+|                    +-------------------------------------+--------------+-----------------------------------------+
+|                    |``simbert-base-chinese``             | Chinese      | 12-layer, 768-hidden,                   |
+|                    |                                     |              | 12-heads, 108M parameters.              |
+|                    |                                     |              | Trained on 22 million pairs of similar  |
+|                    |                                     |              | sentences crawled from Baidu Know.      |
 +--------------------+-------------------------------------+--------------+-----------------------------------------+
 |BigBird_            |``bigbird-base-uncased``             | English      | 12-layer, 768-hidden,                   |
 |                    |                                     |              | 12-heads, _M parameters.                |
@@ -427,6 +432,7 @@ Reference
    `huggingface/xlnet_chinese_large <https://huggingface.co/clue/xlnet_chinese_large>`_,
    `Knover/luge-dialogue <https://github.com/PaddlePaddle/Knover/tree/luge-dialogue/luge-dialogue>`_,
    `huawei-noah/Pretrained-Language-Model/NEZHA-PyTorch/ <https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-PyTorch>`_
+   `ZhuiyiTechnology/simbert <https://github.com/ZhuiyiTechnology/simbert>`_
 - Lan, Zhenzhong, et al. "Albert: A lite bert for self-supervised learning of language representations." arXiv preprint arXiv:1909.11942 (2019).
 - Devlin, Jacob, et al. "Bert: Pre-training of deep bidirectional transformers for language understanding." arXiv preprint arXiv:1810.04805 (2018).
 - Zaheer, Manzil, et al. "Big bird: Transformers for longer sequences." arXiv preprint arXiv:2007.14062 (2020).
simbert/README.md

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@

# SimBERT Model

## Model Overview

The [SimBERT](https://github.com/ZhuiyiTechnology/simbert) weights start from Google's open-source BERT model and are obtained by further fine-tuning it on a task, designed around Microsoft's UniLM idea, that combines retrieval and generation in a single model. As a result, SimBERT can both generate similar questions and retrieve similar sentences.
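As an illustrative aside (not part of the committed files), the released weight can be loaded directly through PaddleNLP's `BertModel`/`BertTokenizer`, the same way `predict.py` below does; the sketch assumes the `simbert-base-chinese` weight added in this commit is downloadable:

```python
import paddle
import paddlenlp as ppnlp

# Load the SimBERT weight added in this commit. with_pool='linear' skips the
# tanh on the pooled [CLS] vector (see the BertPooler change in modeling.py).
model = ppnlp.transformers.BertModel.from_pretrained(
    'simbert-base-chinese', with_pool='linear')
tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained(
    'simbert-base-chinese')

model.eval()
encoded = tokenizer(text='世界上什么东西最小', max_seq_len=64)
with paddle.no_grad():
    sequence_output, pooled_output = model(
        input_ids=paddle.to_tensor([encoded['input_ids']]),
        token_type_ids=paddle.to_tensor([encoded['token_type_ids']]))

# pooled_output is the sentence vector that predict.py compares with cosine similarity.
print(pooled_output.shape)  # [1, 768]
```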
## Quick Start

### Code Structure

The main files of this project are organized as follows:

```text
simbert/
├── data.py    # data loading and conversion for the input samples
├── predict.py # model prediction
└── README.md  # documentation
```

### Prediction

Start prediction:

```shell
export CUDA_VISIBLE_DEVICES=0
python predict.py --input_file ./datasets/lcqmc/dev.tsv
```

The input file contains one tab-separated text pair per line, for example:

```text
世界上什么东西最小	世界上什么东西最小?
光眼睛大就好看吗	眼睛好看吗?
小蝌蚪找妈妈怎么样	小蝌蚪找妈妈是谁画的
```

Running predict.py as above produces the similarity of each pair:

```text
{'query': '世界上什么东西最小', 'title': '世界上什么东西最小?', 'similarity': 0.992725}
{'query': '光眼睛大就好看吗', 'title': '眼睛好看吗?', 'similarity': 0.74502724}
{'query': '小蝌蚪找妈妈怎么样', 'title': '小蝌蚪找妈妈是谁画的', 'similarity': 0.8192148}
```

## Reference

For more information on SimBERT, see the [科学空间 (Scientific Spaces) blog post](https://spaces.ac.cn/archives/7427).

SimBERT project: https://github.com/ZhuiyiTechnology/simbert
simbert/data.py

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@

import paddle
import numpy as np

from paddlenlp.datasets import MapDataset


def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    """Builds a DataLoader; a distributed batch sampler is used in 'train' mode."""
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)


def read_text_pair(data_path):
    """Reads a tab-separated file of query/title text pairs."""
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = line.rstrip().split("\t")
            if len(data) != 2:
                continue
            yield {'query': data[0], 'title': data[1]}


def convert_example(example, tokenizer, max_seq_length=512, phase="train"):
    """Tokenizes a query/title pair into input ids and token type ids."""
    query, title = example['query'], example['title']

    query_encoded_inputs = tokenizer(text=query, max_seq_len=max_seq_length)
    query_input_ids = query_encoded_inputs["input_ids"]
    query_token_type_ids = query_encoded_inputs["token_type_ids"]

    title_encoded_inputs = tokenizer(text=title, max_seq_len=max_seq_length)
    title_input_ids = title_encoded_inputs["input_ids"]
    title_token_type_ids = title_encoded_inputs["token_type_ids"]

    return query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids
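A minimal sketch (not part of the commit) of how these helpers fit together; it assumes the script is run from the simbert/ directory, the SimBERT tokenizer from this commit is available, and a tab-separated query/title file exists (the file name `dev.tsv` is only a placeholder):

```python
import paddlenlp as ppnlp
from paddlenlp.datasets import load_dataset

from data import convert_example, read_text_pair

tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained(
    'simbert-base-chinese')

# read_text_pair yields {'query': ..., 'title': ...} dicts; convert_example turns
# each dict into four id lists (query/title input ids and token type ids).
ds = load_dataset(read_text_pair, data_path='dev.tsv', lazy=False)
q_ids, q_segments, t_ids, t_segments = convert_example(
    ds[0], tokenizer, max_seq_length=64, phase='predict')
print(len(q_ids), len(t_ids))
```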
simbert/predict.py

Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import argparse
import sys
import os
import random
import time

import numpy as np
import paddle
import paddle.nn.functional as F
import paddlenlp as ppnlp
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad

from data import create_dataloader, read_text_pair
from data import convert_example

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, required=True, help="The full path of input file")
# parser.add_argument("--params_path", type=str, required=True, help="The path to model parameters to be loaded.")
parser.add_argument("--max_seq_length", default=64, type=int, help="The maximum total input sequence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
args = parser.parse_args()
# yapf: enable


def predict(model, data_loader):
    """
    Predicts the similarity.

    Args:
        model (obj:`BertModel`): A model to extract text embedding or calculate similarity of text pair.
        data_loader (obj:`List(Example)`): The processed data ids of text pair: [query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids].
    Returns:
        results (obj:`List`): Cosine similarity of text pairs.
    """
    results = []

    model.eval()

    with paddle.no_grad():
        for batch_data in data_loader:
            query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batch_data
            query_input_ids = paddle.to_tensor(query_input_ids)
            query_token_type_ids = paddle.to_tensor(query_token_type_ids)
            title_input_ids = paddle.to_tensor(title_input_ids)
            title_token_type_ids = paddle.to_tensor(title_token_type_ids)

            vecs_query = model(
                input_ids=query_input_ids, token_type_ids=query_token_type_ids)
            vecs_title = model(
                input_ids=title_input_ids, token_type_ids=title_token_type_ids)
            # The second output of BertModel is the pooled [CLS] vector.
            vecs_query = vecs_query[1].numpy()
            vecs_title = vecs_title[1].numpy()

            # L2-normalize each vector, then the dot product is the cosine similarity.
            vecs_query = vecs_query / (vecs_query**2).sum(axis=1,
                                                          keepdims=True)**0.5
            vecs_title = vecs_title / (vecs_title**2).sum(axis=1,
                                                          keepdims=True)**0.5
            sims = (vecs_query * vecs_title).sum(axis=1)

            results.extend(sims)

    return results


if __name__ == "__main__":
    paddle.set_device(args.device)

    model = ppnlp.transformers.BertModel.from_pretrained(
        'simbert-base-chinese', with_pool='linear')
    tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained(
        'simbert-base-chinese')

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        phase="predict")

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # query_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # query_segment
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # title_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # title_segment
    ): [data for data in fn(samples)]

    valid_ds = load_dataset(
        read_text_pair, data_path=args.input_file, lazy=False)

    valid_data_loader = create_dataloader(
        valid_ds,
        mode='predict',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    y_sims = predict(model, valid_data_loader)

    valid_ds = load_dataset(
        read_text_pair, data_path=args.input_file, lazy=False)

    for idx, prob in enumerate(y_sims):
        text_pair = valid_ds[idx]
        text_pair["similarity"] = y_sims[idx]
        print(text_pair)
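The similarity returned by `predict()` is ordinary cosine similarity: each pooled vector is scaled to unit L2 norm, after which the element-wise product summed over the hidden dimension equals the cosine. A tiny self-contained check (illustrative numbers only):

```python
import numpy as np

a = np.array([[1.0, 2.0, 2.0]])  # stand-in for a query vector
b = np.array([[2.0, 1.0, 2.0]])  # stand-in for a title vector

# Same normalization as predict(): divide by the L2 norm, then sum the products.
a_n = a / (a**2).sum(axis=1, keepdims=True)**0.5
b_n = b / (b**2).sum(axis=1, keepdims=True)**0.5
sims = (a_n * b_n).sum(axis=1)

# Direct cosine similarity for comparison.
cos = (a * b).sum(axis=1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))
assert np.allclose(sims, cos)
print(sims)  # approx. [0.8889], i.e. 8 / (3 * 3)
```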

paddlenlp/transformers/bert/modeling.py

Lines changed: 23 additions & 4 deletions
@@ -75,17 +75,19 @@ class BertPooler(Layer):
     """
     """
 
-    def __init__(self, hidden_size):
+    def __init__(self, hidden_size, with_pool):
         super(BertPooler, self).__init__()
         self.dense = nn.Linear(hidden_size, hidden_size)
         self.activation = nn.Tanh()
+        self.with_pool = with_pool
 
     def forward(self, hidden_states):
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
         first_token_tensor = hidden_states[:, 0]
         pooled_output = self.dense(first_token_tensor)
-        pooled_output = self.activation(pooled_output)
+        if self.with_pool == 'tanh':
+            pooled_output = self.activation(pooled_output)
         return pooled_output
 
 
@@ -253,6 +255,20 @@ class BertPretrainedModel(PretrainedModel):
             "initializer_range": 0.02,
             "pad_token_id": 0,
         },
+        "simbert-base-chinese": {
+            "vocab_size": 13685,
+            "hidden_size": 768,
+            "num_hidden_layers": 12,
+            "num_attention_heads": 12,
+            "intermediate_size": 3072,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.1,
+            "attention_probs_dropout_prob": 0.1,
+            "max_position_embeddings": 512,
+            "type_vocab_size": 2,
+            "initializer_range": 0.02,
+            "pad_token_id": 0,
+        },
     }
     resource_files_names = {"model_state": "model_state.pdparams"}
     pretrained_resource_files_map = {
@@ -279,6 +295,8 @@ class BertPretrainedModel(PretrainedModel):
             "https://paddlenlp.bj.bcebos.com/models/transformers/macbert/macbert-base-chinese.pdparams",
             "macbert-large-chinese":
             "https://paddlenlp.bj.bcebos.com/models/transformers/macbert/macbert-large-chinese.pdparams",
+            "simbert-base-chinese":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/simbert/simbert-base-chinese-v1.pdparams",
         }
     }
     base_model_prefix = "bert"
@@ -353,7 +371,8 @@ def __init__(self,
                  max_position_embeddings=512,
                  type_vocab_size=16,
                  initializer_range=0.02,
-                 pad_token_id=0):
+                 pad_token_id=0,
+                 with_pool='tanh'):
         super(BertModel, self).__init__()
         self.pad_token_id = pad_token_id
         self.initializer_range = initializer_range
@@ -369,7 +388,7 @@ def __init__(self,
                 attn_dropout=attention_probs_dropout_prob,
                 act_dropout=0)
         self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)
-        self.pooler = BertPooler(hidden_size)
+        self.pooler = BertPooler(hidden_size, with_pool)
         self.apply(self.init_weights)
 
     def forward(self,
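The effect of the new `with_pool` argument: `with_pool='tanh'` (the default, unchanged behaviour) applies tanh to the linear projection of the first token's hidden state, while any other value, such as the `'linear'` that predict.py passes for SimBERT, returns the raw projection. A small standalone sketch of the patched pooler logic (`Pooler` is a hypothetical stand-in, not the real class):

```python
import paddle
import paddle.nn as nn


class Pooler(nn.Layer):
    """Illustrative stand-in for the patched BertPooler."""

    def __init__(self, hidden_size, with_pool='tanh'):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self.with_pool = with_pool

    def forward(self, hidden_states):
        # Pool by taking the hidden state of the first ([CLS]) token.
        pooled = self.dense(hidden_states[:, 0])
        if self.with_pool == 'tanh':  # SimBERT passes 'linear', so tanh is skipped
            pooled = self.activation(pooled)
        return pooled


hidden = paddle.randn([2, 8, 768])                        # [batch, seq_len, hidden]
tanh_pooled = Pooler(768, with_pool='tanh')(hidden)       # squashed into (-1, 1)
linear_pooled = Pooler(768, with_pool='linear')(hidden)   # raw linear projection
print(tanh_pooled.shape, linear_pooled.shape)             # [2, 768] [2, 768]
```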

paddlenlp/transformers/bert/tokenizer.py

Lines changed: 5 additions & 0 deletions
@@ -271,6 +271,8 @@ class BertTokenizer(PretrainedTokenizer):
             "https://paddle-hapi.bj.bcebos.com/models/bert/bert-base-chinese-vocab.txt",
             "macbert-base-chinese":
             "https://paddle-hapi.bj.bcebos.com/models/bert/bert-base-chinese-vocab.txt",
+            "simbert-base-chinese":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/simbert/vocab.txt",
         }
     }
     pretrained_init_configuration = {
@@ -307,6 +309,9 @@ class BertTokenizer(PretrainedTokenizer):
         "macbert-base-chinese": {
             "do_lower_case": False
         },
+        "simbert-base-chinese": {
+            "do_lower_case": True
+        },
     }
     padding_side = 'right'
