
Commit 181a1e0

add PET
1 parent 03f0f02 commit 181a1e0


5 files changed: +17 additions, −378 deletions


examples/few_shot/README.md

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ Few-Shot Learning 旨在研究如何从少量有监督的训练样本中学习
 | ------------ | ------------ | ------------ | ------------ | ------------ | ------------ | ------------ | ------------ | ------------ |------------ | ------------ | ---------- |
 | P-tuning | ERNIE1.0 | 55.70 | 83.28 | 63.43 | 35.36 | 60.54 | 50.02 | 54.51 | 50.14 | 54.93 | 41.16 |
 | EFL | ERNIE1.0 | 54.47 | 84.10 | 60.10 | 35.12 | 56.61 | 56.57 | 53.59 | 46.37 | 61.21 | 36.56 |
-| PET | ERNIE1.0 | 56.38 | 86.88 | 61.90 | 36.90 | 61.10 | 56.51 | 55.02 | 50.31 | 59.72 | 39.11 |
+| PET | ERNIE1.0 | 56.63 | 86.88 | 61.90 | 36.90 | 61.10 | 56.51 | 55.02 | 50.31 | 59.72 | 41.35 |
 ## 策略库
 - [P-tuning](./p-tuning)
 - [EFL](./efl)

examples/few_shot/pet/data.py

Lines changed: 7 additions & 167 deletions
@@ -80,19 +80,12 @@ def convert_example(example, tokenizer, max_seq_length=512, is_test=False):
     src_ids = encoded_inputs["input_ids"]
     token_type_ids = encoded_inputs["token_type_ids"]

-    # # Step2: gen p_token_ids
-    # p_tokens = ["[unused{}]".format(i) for i in range(p_embedding_num)]
-    # p_token_ids = tokenizer.convert_tokens_to_ids(p_tokens)
-
-    # Step3: Insert "[MASK]" to src_ids based on start_mask_position
+    # Step2: Insert "[MASK]" to src_ids based on start_mask_position
     src_ids = src_ids[0:start_mask_position] + mask_ids + src_ids[
         start_mask_position:]
     token_type_ids = token_type_ids[0:start_mask_position] + [0] * len(
         mask_ids) + token_type_ids[start_mask_position:]

-    # Stpe4: Insert P-tokens at begin of sentence
-    # src_ids = p_token_ids + src_ids
-
     # calculate mask_positions
     mask_positions = [
         index + start_mask_position for index in range(label_length)
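
For readers following the hunk above: the splice operates on the plain Python lists returned by the ERNIE tokenizer. A minimal stand-alone sketch of the same step, with made-up token ids and positions in place of real tokenizer output:

    # Minimal sketch of the [MASK] splice in convert_example.
    # Token ids and positions below are made up for illustration.
    mask_id = 103                      # assumed id for "[MASK]"
    src_ids = [1, 11, 12, 13, 14, 2]   # e.g. [CLS] w1 w2 w3 w4 [SEP]
    token_type_ids = [0] * len(src_ids)

    start_mask_position = 3            # where the label should appear
    label_length = 2                   # label verbalizes into 2 tokens
    mask_ids = [mask_id] * label_length

    # Splice the [MASK] ids into the encoded sentence.
    src_ids = (src_ids[0:start_mask_position] + mask_ids +
               src_ids[start_mask_position:])
    token_type_ids = (token_type_ids[0:start_mask_position] + [0] * len(mask_ids) +
                      token_type_ids[start_mask_position:])

    # The mask positions are simply the indices of the inserted ids.
    mask_positions = [
        index + start_mask_position for index in range(label_length)
    ]

    assert [src_ids[p] for p in mask_positions] == mask_ids
    print(src_ids)         # [1, 11, 12, 103, 103, 13, 14, 2]
    print(mask_positions)  # [3, 4]
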
@@ -143,129 +136,6 @@ def convert_example(example, tokenizer, max_seq_length=512, is_test=False):
     return src_ids, token_type_ids, mask_positions, mask_lm_labels


-def convert_cluewsc_example(example,
-                            tokenizer,
-                            max_seq_length=512,
-                            is_test=False):
-    """
-    Args:
-        example(obj:`list(str)`): The list of text to be converted to ids.
-        tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`
-            which contains most of the methods. Users should refer to the superclass for more information regarding methods.
-        max_seq_len(obj:`int`): The maximum total input sequence length after tokenization.
-            Sequences longer than this will be truncated, sequences shorter will be padded.
-        p_embedding_num(obj:`int`) The number of p-embedding.
-    Returns:
-        input_ids(obj:`list[int]`): The list of query token ids.
-        token_type_ids(obj: `list[int]`): List of query sequence pair mask.
-        mask_positions(obj: `list[int]`): The list of mask_positions.
-        mask_lm_labels(obj: `list[int]`): The list of mask_lm_labels.
-    """
-
-    # Replace <unk> with '[MASK]'
-
-    # Step1: gen mask ids
-    if is_test:
-        label_length = example["label_length"]
-    else:
-        text_label = example["text_label"]
-        label_length = len(text_label)
-
-    mask_tokens = ["[MASK]"] * label_length
-    mask_ids = tokenizer.convert_tokens_to_ids(mask_tokens)
-
-    sentence1 = example["sentence1"]
-    if "<unk>" in sentence1:
-        start_mask_position = sentence1.index("<unk>") + 1
-        sentence1 = sentence1.replace("<unk>", "")
-        encoded_inputs = tokenizer(text=sentence1, max_seq_len=max_seq_length)
-        src_ids = encoded_inputs["input_ids"]
-        token_type_ids = encoded_inputs["token_type_ids"]
-
-        # # Step2: gen p_token_ids
-        # p_tokens = ["[unused{}]".format(i) for i in range(p_embedding_num)]
-        # p_token_ids = tokenizer.convert_tokens_to_ids(p_tokens)
-
-        # Step3: Insert "[MASK]" to src_ids based on start_mask_position
-        src_ids = src_ids[0:start_mask_position] + mask_ids + src_ids[
-            start_mask_position:]
-        token_type_ids = token_type_ids[0:start_mask_position] + [0] * len(
-            mask_ids) + token_type_ids[start_mask_position:]
-
-        # Stpe4: Insert P-tokens at begin of sentence
-        # src_ids = p_token_ids + src_ids
-
-        # calculate mask_positions
-        mask_positions = [
-            index + start_mask_position for index in range(label_length)
-        ]
-    else:
-        sentence2 = example['sentence2']
-        start_mask_position = sentence2.index("<unk>") + 1
-        sentence2 = sentence2.replace("<unk>", "")
-
-        encoded_inputs = tokenizer(text=sentence2, max_seq_len=max_seq_length)
-        src_ids = encoded_inputs["input_ids"]
-        token_type_ids = encoded_inputs["token_type_ids"]
-        src_ids = src_ids[0:start_mask_position] + mask_ids + src_ids[
-            start_mask_position:]
-        token_type_ids = token_type_ids[0:start_mask_position] + [0] * len(
-            mask_ids) + token_type_ids[start_mask_position:]
-
-        encoded_inputs = tokenizer(text=sentence1, max_seq_len=max_seq_length)
-        sentence1_src_ids = encoded_inputs["input_ids"][1:]
-        src_ids = sentence1_src_ids + src_ids
-        token_type_ids += [1] * len(src_ids)
-        mask_positions = [
-            index + start_mask_position + len(sentence1)
-            for index in range(label_length)
-        ]
-
-    token_type_ids = [0] * len(src_ids)
-
-    assert len(src_ids) == len(
-        token_type_ids), "length src_ids, token_type_ids must be equal"
-
-    length = len(src_ids)
-    if length > 512:
-        src_ids = src_ids[:512]
-        token_type_ids = token_type_ids[:512]
-
-    if is_test:
-        import jieba.posseg as pseg
-        judge = 0
-
-        def isname(single_word_string):
-            pair_word_list = pseg.lcut(single_word_string)
-            for eve_word, cixing in pair_word_list:
-                if cixing == "nr":
-                    return True
-            return False
-
-        text_ori = example["target"]["span1_text"]
-        text_daici = example["target"]["span2_text"]
-        if isname(text_ori) and text_daici == "它":
-            judge = 1
-        if ("妈" in text_ori or "姨" in text_ori or "婆" in text_ori or
-                "太太" in text_ori or "妻" in text_ori or "姐" in text_ori or
-                "妹" in text_ori) and ("他" in text_daici):
-            judge = 1
-        if ("爸" in text_ori or "叔" in text_ori or "公" in text_ori or
-                "夫" in text_ori or "哥" in text_ori or
-                "弟" in text_ori) and ("她" in text_daici):
-            judge = 1
-        # print(paddle.to_tensor(judge, dtype="int64"))
-        return src_ids, token_type_ids, mask_positions, judge
-    else:
-        mask_lm_labels = tokenizer(
-            text=text_label, max_seq_len=max_seq_length)["input_ids"][1:-1]
-        assert len(mask_lm_labels) == len(
-            mask_positions
-        ) == label_length, "length of mask_lm_labels:{} mask_positions:{} label_length:{} not equal".format(
-            mask_lm_labels, mask_positions, text_label)
-        return src_ids, token_type_ids, mask_positions, mask_lm_labels
-
-
 def convert_chid_example(example, tokenizer, max_seq_length=512, is_test=False):
     """
     Args:
@@ -283,17 +153,12 @@ def convert_chid_example(example, tokenizer, max_seq_length=512, is_test=False):
         mask_lm_labels(obj: `list[int]`): The list of mask_lm_labels.
     """
     # FewClue Task `Chid`' label's position must be calculated by special token: "淠"
-    # FewClue Task `Chid`' label's position must be calculated by special token: "龜"

     seg_tokens = tokenizer.tokenize(example["sentence1"])

     # find insert position of `[MASK]`
-    # start_mask_position = seg_tokens.index("淠") + 1
-    # seg_tokens.remove("淠")
-    # start_mask_position = seg_tokens.index("龜") + 1
-    # seg_tokens.remove("龜")
-    start_mask_position = seg_tokens.index("[UNK]") + 1
-    seg_tokens.remove("[UNK]")
+    start_mask_position = seg_tokens.index("淠") + 1
+    seg_tokens.remove("淠")

     sentence1 = "".join(seg_tokens)
     candidates = example["candidates"]
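
The hunk above restores "淠" as the sentinel for the FewCLUE `chid` task: the idiom placeholder is swapped for a rarely-used character before tokenization so its position can be recovered afterwards. A small sketch of that trick, using a character-level split as a stand-in for `tokenizer.tokenize()` and a made-up sentence:

    # Sketch of the sentinel-character trick used for `chid`.
    # A character split stands in for tokenizer.tokenize(); the real code
    # relies on "淠" surviving ERNIE tokenization as a single token.
    sentence1 = "他说话总是淠，让人摸不着头脑"   # "#idiom#" already replaced by "淠"

    seg_tokens = list(sentence1)              # stand-in for tokenizer.tokenize(...)

    # +1 accounts for the [CLS] token the tokenizer will prepend later.
    start_mask_position = seg_tokens.index("淠") + 1
    seg_tokens.remove("淠")

    sentence1 = "".join(seg_tokens)
    print(start_mask_position)  # 6
    print(sentence1)            # 他说话总是，让人摸不着头脑
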
@@ -317,19 +182,12 @@ def convert_chid_example(example, tokenizer, max_seq_length=512, is_test=False):
     mask_tokens = ["[MASK]"] * label_length
     mask_ids = tokenizer.convert_tokens_to_ids(mask_tokens)

-    # Step2: gen p_token_ids
-    # p_tokens = ["[unused{}]".format(i) for i in range(p_embedding_num)]
-    # p_token_ids = tokenizer.convert_tokens_to_ids(p_tokens)
-
-    # Step3: Insert "[MASK]" to src_ids based on start_mask_position
+    # Step2: Insert "[MASK]" to src_ids based on start_mask_position
     src_ids = src_ids[0:start_mask_position] + mask_ids + src_ids[
         start_mask_position:]
     token_type_ids = token_type_ids[0:start_mask_position] + [0] * len(
         mask_ids) + token_type_ids[start_mask_position:]

-    # Stpe4: Insert P-tokens at begin of sentence
-    # src_ids = p_token_ids + src_ids
-
     # calculate mask_positions
     mask_positions = [
         index + start_mask_position for index in range(label_length)
@@ -577,12 +435,6 @@ def transform_csldcp(example,

         if pattern_id == 0:
             example["sentence1"] = u'这篇关于<unk>的文章讲了' + example["content"]
-        # elif pattern_id == 1:
-        #     example["sentence1"] = example["content"] + u'这是一篇关于<unk>的文章'
-        # elif pattern_id == 2:
-        #     example["sentence1"] = example["content"] + u'这是和<unk>有关的文章'
-        # elif pattern_id == 3:
-        #     example["sentence1"] = example["content"] + u'这些与<unk>有关'
         elif pattern_id == 1:
             example["sentence1"] = example["content"] + u'和<unk>息息相关'
         elif pattern_id == 2:
@@ -599,12 +451,6 @@ def transform_csldcp(example,
         example['text_label'] = normalized_label
         if pattern_id == 0:
             example["sentence1"] = u'这篇关于<unk>的文章讲了' + example["content"]
-        # elif pattern_id == 1:
-        #     example["sentence1"] = example["content"] + u'这是一篇关于<unk>的文章'
-        # elif pattern_id == 2:
-        #     example["sentence1"] = example["content"] + u'这是和<unk>有关的文章'
-        # elif pattern_id == 3:
-        #     example["sentence1"] = example["content"] + u'这些与<unk>有关'
         elif pattern_id == 1:
             example["sentence1"] = example["content"] + u'和<unk>息息相关'
         elif pattern_id == 2:
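
The two `transform_csldcp` hunks above keep only the active prompt patterns and leave the alternatives as comments. For reference, a condensed sketch of how a pattern wraps the document content around the `<unk>` placeholder (later turned into `[MASK]` slots by `convert_example`); `apply_csldcp_pattern` is a hypothetical helper written for illustration, not part of data.py:

    # Condensed sketch of the csldcp prompt patterns kept in the diff.
    # apply_csldcp_pattern is a hypothetical helper, not part of data.py.
    def apply_csldcp_pattern(content, pattern_id=0):
        if pattern_id == 0:
            return u'这篇关于<unk>的文章讲了' + content
        elif pattern_id == 1:
            return content + u'和<unk>息息相关'
        # Further pattern_id values follow the same shape in data.py.
        raise ValueError("pattern_id not covered by this sketch")

    # The content string below is made up.
    print(apply_csldcp_pattern(u'图神经网络的可解释性研究综述', pattern_id=1))
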
@@ -661,9 +507,7 @@ def transform_chid(example,

     if is_test:
         example["label_length"] = 4
-        # example["sentence1"] = example["content"].replace("#idiom#", "淠")
-        # example["sentence1"] = example["content"].replace("#idiom#", "龜")
-        example["sentence1"] = example["content"].replace("#idiom#", "蠅")
+        example["sentence1"] = example["content"].replace("#idiom#", "淠")
         del example["content"]

         return example
@@ -675,11 +519,7 @@ def transform_chid(example,
         # Note: `#idom#` represent a idom which must be replaced with rarely-used Chinese characters
         # to get the label's position after the text processed by tokenizer
         #ernie
-        # example["sentence1"] = example["content"].replace("#idiom#", "淠")
-        #albert
-        # example["sentence1"] = example["content"].replace("#idiom#", "龜")
-        #macbert
-        example["sentence1"] = example["content"].replace("#idiom#", "蠅")
+        example["sentence1"] = example["content"].replace("#idiom#", "淠")
         del example["content"]

         return example
@@ -744,4 +584,4 @@ def transform_cluewsc(example,
     "csldcp": transform_csldcp,
     "cluewsc": transform_cluewsc,
     "chid": transform_chid
-}
+}

examples/few_shot/pet/evaluate.py

Lines changed: 1 addition & 40 deletions
@@ -34,11 +34,6 @@ def do_evaluate(model, tokenizer, data_loader, label_normalize_dict):

     for batch in data_loader:
         src_ids, token_type_ids, masked_positions, masked_lm_labels = batch
-        # [bs * label_length, vocab_size]
-        # prediction_probs = model.predict(
-        #     input_ids=src_ids,
-        #     token_type_ids=token_type_ids,
-        #     masked_positions=masked_positions)

         max_len = src_ids.shape[1]
         new_masked_positions = []
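
Right after this hunk, `do_evaluate` flattens the per-example mask positions into offsets over the whole batch; the same loop is visible verbatim in the commented-out block removed from `do_evaluate_cluewsc` below, where the flat indices are passed to the model as `masked_positions`. A numpy-only sketch of that index arithmetic, with made-up positions:

    import numpy as np

    # Sketch of the position flattening used in do_evaluate: a (batch, label_length)
    # matrix of mask positions becomes flat offsets into a [batch * max_len, ...] view.
    masked_positions = np.array([[3, 4],
                                 [7, 8]])   # made-up positions for 2 examples
    max_len = 16                            # plays the role of src_ids.shape[1]

    new_masked_positions = []
    for bs_index, mask_pos in enumerate(masked_positions):
        for pos in mask_pos:
            new_masked_positions.append(bs_index * max_len + pos)
    new_masked_positions = np.array(new_masked_positions).astype('int32')

    print(new_masked_positions)  # [ 3  4 23 24]
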
@@ -113,44 +108,10 @@ def do_evaluate_cluewsc(model, tokenizer, data_loader, label_normalize_dict):
         prediction_probs = model(
             input_ids=src_ids, token_type_ids=token_type_ids)

-        # max_len = src_ids.shape[1]
-        # new_masked_positions = []
-
-        # for bs_index, mask_pos in enumerate(masked_positions.numpy()):
-        #     for pos in mask_pos:
-        #         new_masked_positions.append(bs_index * max_len + pos)
-        # new_masked_positions = paddle.to_tensor(np.array(new_masked_positions).astype('int32'))
-
-        # prediction_scores, _ = model(
-        #     input_ids=src_ids,
-        #     token_type_ids=token_type_ids,
-        #     masked_positions=new_masked_positions)
-        # softmax_fn = paddle.nn.Softmax()
-        # prediction_probs = softmax_fn(prediction_scores)
-
-        # batch_size = len(src_ids)
-        # vocab_size = 2
-
-        # # prediction_probs: [batch_size, label_lenght, vocab_size]
-        # prediction_probs = paddle.reshape(
-        #     prediction_probs, shape=[batch_size, -1, vocab_size]).numpy()
-
-        # # [label_num, label_length]
-        # label_ids = np.array(
-        #     [tokenizer(label)["input_ids"][1:-1] for label in normed_labels])
-
-        # y_pred = np.ones(shape=[batch_size, len(label_ids)])
-
-        # # Calculate joint distribution of candidate labels
-        # for index in range(label_length):
-        #     y_pred *= prediction_probs[:, index, label_ids[:, index]]
-
         # Get max probs label's index
         y_pred_index = paddle.argmax(prediction_probs, axis=-1).numpy()
         y_true_index = []
-        # print(y_pred_index)
-        # print(label_idx)
-        # print()
+
         for label_i in label_idx.numpy():
             y_true_index.append(label_i)
         y_true_index = np.array(y_true_index)