
Commit 97b52bc

Author: tianxin

add ernie_matching point-wise & pair-wise (#404)
* add ernie_matching point-wise & pair-wise
* Rename some Class
* fix some typo
1 parent 2e64244 commit 97b52bc

File tree

6 files changed: +919 −0 lines changed
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np

from paddlenlp.datasets import MapDataset


def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)


def read_text_pair(data_path):
    """Reads tab-separated query/title pairs, skipping malformed lines."""
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = line.rstrip().split("\t")
            if len(data) != 2:
                continue
            yield {'query': data[0], 'title': data[1]}


def convert_pointwise_example(example,
                              tokenizer,
                              max_seq_length=512,
                              is_test=False):
    query, title = example["query"], example["title"]

    encoded_inputs = tokenizer(
        text=query, text_pair=title, max_seq_len=max_seq_length)

    input_ids = encoded_inputs["input_ids"]
    token_type_ids = encoded_inputs["token_type_ids"]

    if not is_test:
        label = np.array([example["label"]], dtype="int64")
        return input_ids, token_type_ids, label
    else:
        return input_ids, token_type_ids


def convert_pairwise_example(example,
                             tokenizer,
                             max_seq_length=512,
                             phase="train"):
    if phase == "train":
        query, pos_title, neg_title = example["query"], example[
            "title"], example["neg_title"]

        pos_inputs = tokenizer(
            text=query, text_pair=pos_title, max_seq_len=max_seq_length)
        neg_inputs = tokenizer(
            text=query, text_pair=neg_title, max_seq_len=max_seq_length)

        pos_input_ids = pos_inputs["input_ids"]
        pos_token_type_ids = pos_inputs["token_type_ids"]
        neg_input_ids = neg_inputs["input_ids"]
        neg_token_type_ids = neg_inputs["token_type_ids"]

        return (pos_input_ids, pos_token_type_ids, neg_input_ids,
                neg_token_type_ids)
    else:
        query, title = example["query"], example["title"]

        inputs = tokenizer(
            text=query, text_pair=title, max_seq_len=max_seq_length)

        input_ids = inputs["input_ids"]
        token_type_ids = inputs["token_type_ids"]
        if phase == "eval":
            return input_ids, token_type_ids, example["label"]
        elif phase == "predict":
            return input_ids, token_type_ids
        else:
            raise ValueError("not supported phase: {}".format(phase))


def gen_pair(dataset, pool_size=100):
    """
    Generate triplets randomly from a pointwise dataset.

    Args:
        dataset: A `MapDataset` whose examples hold two texts and a label:
            example["query"], example["title"], example["label"].
        pool_size: The number of positive examples pooled together before
            negative titles are sampled from the pool.

    Returns:
        dataset: A `MapDataset` whose examples hold three texts:
            example["query"], example["title"] (positive) and
            example["neg_title"] (randomly sampled negative).
    """

    if len(dataset) < pool_size:
        pool_size = len(dataset)

    new_examples = []
    pool = []
    tmp_examples = []

    for example in dataset:
        label = example["label"]

        # Filter out negative examples; only positive pairs are kept.
        if label == 0:
            continue

        tmp_examples.append(example)
        pool.append(example["title"])

        # Once the pool is full, shuffle the pooled titles and assign one
        # to each pooled example as its randomly drawn negative title.
        if len(pool) >= pool_size:
            np.random.shuffle(pool)
            for idx, example in enumerate(tmp_examples):
                example["neg_title"] = pool[idx]
                new_examples.append(example)
            tmp_examples = []
            pool = []
        else:
            continue
    return MapDataset(new_examples)
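
For orientation, a minimal usage sketch of the helpers above, as if run in the same module; the three toy examples are invented for illustration and are not part of the commit:

from paddlenlp.datasets import MapDataset

raw = MapDataset([
    {"query": "q1", "title": "t1", "label": 1},
    {"query": "q2", "title": "t2", "label": 1},
    {"query": "q3", "title": "t3", "label": 0},  # label 0: dropped by gen_pair
])

# Pool the positive titles, shuffle them, and attach one to each example
# as a randomly drawn neg_title.
triplets = gen_pair(raw, pool_size=2)
for ex in triplets:
    print(ex["query"], ex["title"], ex["neg_title"])

Note that with a small pool an example can draw its own title back as its negative; the code does not exclude that case.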
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class PointwiseMatching(nn.Layer):
    def __init__(self, pretrained_model, dropout=None):
        super().__init__()
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)

        # num_labels = 2 (similar or dissimilar)
        self.classifier = nn.Linear(self.ptm.config["hidden_size"], 2)

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
                                    attention_mask)

        cls_embedding = self.dropout(cls_embedding)
        logits = self.classifier(cls_embedding)
        probs = F.softmax(logits)

        return probs


class PairwiseMatching(nn.Layer):
    def __init__(self, pretrained_model, dropout=None, margin=0.1):
        super().__init__()
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
        self.margin = margin

        # hidden_size -> 1: a single scalar similarity score
        self.similarity = nn.Linear(self.ptm.config["hidden_size"], 1)

    def predict(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        _, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,
                                    attention_mask)

        cls_embedding = self.dropout(cls_embedding)
        sim_score = self.similarity(cls_embedding)
        sim_score = F.sigmoid(sim_score)

        return sim_score

    def forward(self,
                pos_input_ids,
                neg_input_ids,
                pos_token_type_ids=None,
                neg_token_type_ids=None,
                pos_position_ids=None,
                neg_position_ids=None,
                pos_attention_mask=None,
                neg_attention_mask=None):
        _, pos_cls_embedding = self.ptm(pos_input_ids, pos_token_type_ids,
                                        pos_position_ids, pos_attention_mask)

        _, neg_cls_embedding = self.ptm(neg_input_ids, neg_token_type_ids,
                                        neg_position_ids, neg_attention_mask)

        pos_embedding = self.dropout(pos_cls_embedding)
        neg_embedding = self.dropout(neg_cls_embedding)

        pos_sim = self.similarity(pos_embedding)
        neg_sim = self.similarity(neg_embedding)

        pos_sim = F.sigmoid(pos_sim)
        neg_sim = F.sigmoid(neg_sim)

        labels = paddle.full(
            shape=[pos_cls_embedding.shape[0]], fill_value=1.0, dtype='float32')

        loss = F.margin_ranking_loss(
            pos_sim, neg_sim, labels, margin=self.margin)

        return loss
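
A worked note on the objective: with the labels fixed to 1.0, paddle.nn.functional.margin_ranking_loss reduces to max(0, neg_sim - pos_sim + margin), so the positive pair must outscore the negative pair by at least margin before the loss reaches zero. A standalone numpy sketch of that reduction (illustrative only, not part of the commit):

import numpy as np

def margin_ranking_loss(pos_sim, neg_sim, margin=0.1):
    # Hinge on the score gap: zero loss once pos_sim >= neg_sim + margin.
    return np.maximum(0.0, neg_sim - pos_sim + margin)

print(margin_ranking_loss(0.9, 0.2))  # 0.0 -> ranking already satisfied
print(margin_ranking_loss(0.5, 0.6))  # 0.2 -> negative outranks the positive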
Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import argparse
import os

import numpy as np
import paddle
import paddlenlp as ppnlp
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Tuple, Pad

from data import create_dataloader, read_text_pair
from data import convert_pairwise_example as convert_example
from model import PairwiseMatching

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, required=True, help="The full path of the input file.")
parser.add_argument("--params_path", type=str, required=True, help="The path of the model parameters to be loaded.")
parser.add_argument("--max_seq_length", default=64, type=int, help="The maximum total input sequence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for prediction.")
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run the model, defaults to gpu.")
args = parser.parse_args()
# yapf: enable


def predict(model, data_loader):
    """
    Predicts similarity scores for text pairs.

    Args:
        model (obj:`PairwiseMatching`): A model that scores the similarity of a text pair.
        data_loader (obj:`paddle.io.DataLoader`): Batches of processed text-pair ids: [input_ids, token_type_ids].

    Returns:
        results (obj:`np.ndarray`): Similarity probabilities of the text pairs.
    """
    batch_probs = []

    model.eval()

    with paddle.no_grad():
        for batch_data in data_loader:
            input_ids, token_type_ids = batch_data

            input_ids = paddle.to_tensor(input_ids)
            token_type_ids = paddle.to_tensor(token_type_ids)

            batch_prob = model.predict(
                input_ids=input_ids, token_type_ids=token_type_ids).numpy()

            batch_probs.append(batch_prob)

    batch_probs = np.concatenate(batch_probs, axis=0)

    return batch_probs


if __name__ == "__main__":
    paddle.set_device(args.device)

    tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        phase="predict")

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment_ids
    ): [data for data in fn(samples)]

    valid_ds = load_dataset(
        read_text_pair, data_path=args.input_file, lazy=False)

    valid_data_loader = create_dataloader(
        valid_ds,
        mode='predict',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    pretrained_model = ppnlp.transformers.ErnieModel.from_pretrained(
        "ernie-1.0")

    model = PairwiseMatching(pretrained_model)

    if args.params_path and os.path.isfile(args.params_path):
        state_dict = paddle.load(args.params_path)
        model.set_dict(state_dict)
        print("Loaded parameters from %s" % args.params_path)
    else:
        raise ValueError(
            "Please set --params_path with a correct pretrained model file")

    y_probs = predict(model, valid_data_loader)

    # Reload the raw pairs (the first load was consumed by the trans_fn
    # mapping) so each prediction prints next to its original text pair.
    valid_ds = load_dataset(
        read_text_pair, data_path=args.input_file, lazy=False)
    for idx, prob in enumerate(y_probs):
        text_pair = valid_ds[idx]
        text_pair["pred_prob"] = prob[0]
        print(text_pair)
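
A hypothetical invocation of the script above — its filename is not shown in this view, and both paths below are placeholders:

python predict_pairwise.py \
    --device gpu \
    --params_path ./checkpoints/model_state.pdparams \
    --input_file ./test.tsv \
    --batch_size 32

The input file is expected to hold one tab-separated query/title pair per line (see read_text_pair); each printed result is the input pair with a pred_prob field appended.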
