Skip to content

Commit 740f5e2

Browse files
smallv0221LiuChiachiFrostML
authored
Add unimo model and fix generate api (#891)
* fix unified transformer dtype problem * fix win dtype bug * Fix plato-2 and plato-mini dtype bug * Fix plato-2 tokenization * Refine some doc * Add general k support for topk sampling * fix seed * minor fix * Fix unitransformer readme * topk kernel optimization * add unimo model and fix generate api * add 3 datasets for unimo-text Co-authored-by: Jiaqi Liu <[email protected]> Co-authored-by: liu zhengxi <[email protected]>
1 parent 4cdd4c0 commit 740f5e2

File tree

11 files changed

+1831
-18
lines changed

11 files changed

+1831
-18
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# 千言:面向事实一致性的生成评测比赛baseline
2+
3+
## 比赛简介
4+
5+
自然语言生成旨在让机器能够像人一样使用自然语言进行表达和交互,它是人工智能领域重要的前沿课题,近年来受到学术界和工业界广泛关注。
6+
7+
随着神经网络生成模型特别是预训练语言模型的迅速发展,机器生成文本的可读性和流畅性不断提升。然而,自动生成的文本中依然经常出现不符合原文或背景的错误事实描述,这种生成的事实一致性问题是自然语言生成进行落地应用的主要障碍之一,并逐渐受到研究学者的关注。鉴于当前国内外关于事实一致性的生成评测比赛十分匮乏,为了促进自然语言生成的技术发展和实际应用,我们计划组织面向事实一致性的生成评测比赛。
8+
9+
在此比赛中,我们将提供三个对事实一致性有较高要求的生成任务,包括文案生成、摘要生成和问题生成。同时,在系统评价中,我们将结合文本流畅性和事实一致性两项指标综合评估参赛生成系统的水平。通过这样的任务设定和评价方式,此评测将有助于研究者和开发者更多关注自然语言生成的事实一致性难题,并为大家提供学术交流平台,从而进一步提升自然语言生成的研究水平,推动相关技术的应用发展。
10+
11+
本比赛得到中国中文信息学会自然语言生成专业委员会(筹)支持,将在2021年11月7日首届中国自然语言生成大会(CCNLG-2021)召开评测研讨会,并在大会上对获奖团队颁奖。
12+
13+
14+
## 快速开始
15+
16+
### 数据准备
17+
18+
比赛使用三个任务数据集测试参赛系统的生成能力,包括文案生成、摘要生成和问题生成:
19+
20+
- 文案生成根据结构化的商品信息生成合适的广告文案;
21+
- 摘要生成是为输入文档生成简洁且包含关键信息的简洁文本;
22+
- 问题生成则是根据给定段落以及答案生成适合的问题。
23+
24+
25+
### 模型训练
26+
27+
运行如下命令即可在样例训练集上进行finetune,并在样例验证集上进行验证
28+
29+
```shell
30+
# GPU启动,参数`--gpus`指定训练所用的GPU卡号,可以是单卡,也可以多卡
31+
unset CUDA_VISIBLE_DEVICES
32+
python -m paddle.distributed.launch --gpus "0" --log_dir ./log run_gen.py \
33+
--dataset_name=dureader_qg \
34+
--model_name_or_path=unimo-text-1.0 \
35+
--save_dir=./unimo/checkpoints \
36+
--logging_steps=100 \
37+
--save_steps=100000 \
38+
--epochs=6 \
39+
--batch_size=16 \
40+
--learning_rate=5e-5 \
41+
--warmup_propotion=0.02 \
42+
--weight_decay=0.01 \
43+
--max_seq_len=512 \
44+
--max_target_len=30 \
45+
--do_train \
46+
--do_predict \
47+
--device=gpu
48+
```
49+
50+
其中参数释义如下:
51+
- `gpus` 指示了训练所用的GPU卡号。
52+
- `dataset_name` 数据集名称,dureader_qg、advertisegen和lcsts_new分别对应问题生成、文案生成和摘要生成三个任务。
53+
- `model_name_or_path` 指示了finetune使用的具体预训练模型,可以是PaddleNLP提供的预训练模型,或者是本地的预训练模型。如果使用本地的预训练模型,可以配置本地模型的目录地址,例如: ./checkpoints/model_xx/,目录中需包含paddle预训练模型model_state.pdparams。如果使用PaddleNLP提供的预训练模型,可以选择下面其中之一。
54+
55+
| PaddleNLP提供的预训练模型 |
56+
|---------------------------------|
57+
| unimo-text-1.0 |
58+
| unimo-text-1.0-large |
59+
60+
- `save_dir` 表示模型的保存路径。
61+
- `logging_steps` 表示日志打印间隔。
62+
- `save_steps` 表示模型保存及评估间隔。
63+
- `seed` 表示随机数生成器的种子。
64+
- `epochs` 表示训练轮数。
65+
- `batch_size` 表示每次迭代**每张卡**上的样本数目。
66+
- `learning_rate` 表示基础学习率大小,将于learning rate scheduler产生的值相乘作为当前学习率。
67+
- `weight_decay` 表示AdamW优化器中使用的weight_decay的系数。
68+
- `warmup_propotion` 表示学习率逐渐升高到基础学习率(即上面配置的learning_rate)所需要的迭代数占总步数的比例,最早的使用可以参考[这篇论文](https://arxiv.org/pdf/1706.02677.pdf)
69+
- `max_seq_len` 模型输入序列的最大长度。
70+
- `max_target_len` 模型训练时标签的最大长度。
71+
- `do_train` 是否进行训练。
72+
- `do_predict` 是否进行预测,在验证集上会自动评估。
73+
- `device` 表示使用的设备,从gpu和cpu中选择。
74+
75+
更多参数详情和参数的默认值请参考`args.py`
76+
77+
程序运行时将会自动进行训练和验证,训练过程中会自动保存模型在指定的`save_dir`中。
78+
如:
79+
```text
80+
./checkpoints/
81+
├── model_8000
82+
│ ├── model_config.json
83+
│ ├── model_state.pdparams
84+
│ ├── spm.model
85+
│ ├── tokenizer_config.json
86+
│ └── vocab.txt
87+
└── ...
88+
```
89+
90+
**NOTE:** 如需恢复模型训练,`model_name_or_path`配置本地模型的目录地址即可。
91+
92+
### 模型预测
93+
94+
运行如下命令即可在样例测试集上进行测试
95+
96+
```shell
97+
export CUDA_VISIBLE_DEVICES=0
98+
# GPU启动,预测仅支持单卡
99+
python infer.py \
100+
--model_name_or_path=./checkpoints/model_80000 \
101+
--test_data_path=./datasets/test.txt \
102+
--output_path=./predict.txt \
103+
--logging_steps=500 \
104+
--seed=2021 \
105+
--batch_size=4 \
106+
--min_dec_len=1 \
107+
--max_dec_len=64 \
108+
--num_samples=20 \
109+
--decode_strategy=sampling \
110+
--top_k=5 \
111+
--device=gpu
112+
```
113+
114+
其中参数释义如下:
115+
- `model_name_or_path` 指示了finetune使用的具体预训练模型,可以是PaddleNLP提供的预训练模型,或者是本地的预训练模型。如果使用本地的预训练模型,可以配置本地模型的目录地址,例如: ./checkpoints/model_xx/,目录中需包含paddle预训练模型model_state.pdparams。如果使用PaddleNLP提供的预训练模型,可以选择下面其中之一。
116+
117+
| PaddleNLP提供的预训练模型 |
118+
|---------------------------------|
119+
| unified_transformer-12L-cn |
120+
| unified_transformer-12L-cn-luge |
121+
122+
- `test_data_path` 表示预测集文件路径。
123+
- `output_path` 表示预测结果的保存路径。
124+
- `logging_steps` 表示日志打印间隔。
125+
- `seed` 表示随机数生成器的种子。
126+
- `batch_size` 表示每次迭代**每张卡**上的样本数目。
127+
- `min_dec_len` 表示预测生成的句子的最小长度。
128+
- `max_dec_len` 表示预测生成的句子的最大长度。
129+
- `num_samples` 表示每条样本生成的句子的数量。对于每条样本,模型会生成`num_samples`个句子,根据每个句子的概率得分进行排序,得分最高的句子作为最终的生成结果。
130+
- `decode_strategy` 表示预测解码时采取的策略,可选"sampling"、"greedy_search"和"beam_search"之一。
131+
- `top_k` 表示采用"sampling"解码策略时,token的概率按从大到小排序,生成的token只从前`top_k`个中进行采样。
132+
- `device` 表示训练使用的设备。
133+
134+
参数详情和参数的默认值请参考`args.py`
135+
136+
程序运行结束后会将预测结果保存在`output_path`中。将预测结果准备成比赛官网要求的格式,提交评估即可得评估结果。
137+
138+
采用不同的模型在样例测试集上有如下结果:
139+
140+
| model_name_or_path | F1 | BLEU1 / BLEU2 | DISTINCT1 / DISTINCT2 |
141+
| :-----------------------------: | :---: | :-----------: | :-------------------: |
142+
| unified_transformer-12L-cn | 10.62 | 0.070 / 0.022 | 0.065 / 0.304 |
143+
| unified_transformer-12L-cn-luge | 33.11 | 0.245 / 0.157 | 0.074 / 0.238 |
144+
| ./checkpoints/model_80000 | 32.38 | 0.239 / 0.150 | 0.070 / 0.219 |
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import random
2+
from functools import partial
3+
4+
import numpy as np
5+
6+
import paddle
7+
import paddle.distributed as dist
8+
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler
9+
from paddlenlp.data import Pad
10+
11+
12+
def print_args(args):
13+
print('----------- Configuration Arguments -----------')
14+
for arg, value in sorted(vars(args).items()):
15+
print('%s: %s' % (arg, value))
16+
print('------------------------------------------------')
17+
18+
19+
def set_seed(seed):
20+
# Use the same data seed(for data shuffle) for all procs to guarantee data
21+
# consistency after sharding.
22+
random.seed(seed)
23+
np.random.seed(seed)
24+
# Maybe different op seeds(for dropout) for different procs is better.
25+
paddle.seed(seed + dist.get_rank())
26+
27+
28+
def convert_example(example,
29+
tokenizer,
30+
max_seq_len=512,
31+
max_target_len=128,
32+
max_title_len=256,
33+
mode='train'):
34+
"""Convert all examples into necessary features."""
35+
source = example['source']
36+
title = None
37+
if 'title' in example.keys():
38+
title = example['title']
39+
40+
if mode != 'test':
41+
tokenized_example = tokenizer.gen_encode(
42+
source,
43+
title=title,
44+
target=example['target'],
45+
max_seq_len=max_seq_len,
46+
max_target_len=max_target_len,
47+
max_title_len=max_title_len,
48+
return_position_ids=True,
49+
return_length=True)
50+
target_start = tokenized_example['input_ids'].index(
51+
tokenizer.cls_token_id, 1)
52+
target_end = tokenized_example['seq_len']
53+
# Use to gather the logits corresponding to the labels during training
54+
tokenized_example['masked_positions'] = list(
55+
range(target_start, target_end - 1))
56+
tokenized_example['labels'] = tokenized_example['input_ids'][
57+
target_start + 1:target_end]
58+
59+
return tokenized_example
60+
else:
61+
tokenized_example = tokenizer.gen_encode(
62+
source,
63+
title=title,
64+
max_seq_len=max_seq_len,
65+
max_title_len=max_title_len,
66+
add_start_token_for_decoding=True,
67+
return_position_ids=True)
68+
69+
if 'target' in example:
70+
tokenized_example['target'] = example['target']
71+
return tokenized_example
72+
73+
74+
def batchify_fn(batch_examples, pad_val, mode):
75+
def pad_mask(batch_attention_mask):
76+
batch_size = len(batch_attention_mask)
77+
max_len = max(map(len, batch_attention_mask))
78+
attention_mask = np.ones(
79+
(batch_size, max_len, max_len), dtype='float32') * -1e9
80+
for i, mask_data in enumerate(attention_mask):
81+
seq_len = len(batch_attention_mask[i])
82+
mask_data[-seq_len:, -seq_len:] = np.array(
83+
batch_attention_mask[i], dtype='float32')
84+
# In order to ensure the correct broadcasting mechanism, expand one
85+
# dimension to the second dimension (n_head of Transformer).
86+
attention_mask = np.expand_dims(attention_mask, axis=1)
87+
return attention_mask
88+
89+
pad_func = Pad(pad_val=pad_val, pad_right=False, dtype='int64')
90+
91+
input_ids = pad_func([example['input_ids'] for example in batch_examples])
92+
token_type_ids = pad_func(
93+
[example['token_type_ids'] for example in batch_examples])
94+
position_ids = pad_func(
95+
[example['position_ids'] for example in batch_examples])
96+
97+
attention_mask = pad_mask(
98+
[example['attention_mask'] for example in batch_examples])
99+
100+
if mode != 'test':
101+
max_len = max([example['seq_len'] for example in batch_examples])
102+
masked_positions = np.concatenate([
103+
np.array(example['masked_positions']) +
104+
(max_len - example['seq_len']) + i * max_len
105+
for i, example in enumerate(batch_examples)
106+
])
107+
labels = np.concatenate([
108+
np.array(
109+
example['labels'], dtype='int64') for example in batch_examples
110+
])
111+
return input_ids, token_type_ids, position_ids, attention_mask, masked_positions, labels
112+
else:
113+
return input_ids, token_type_ids, position_ids, attention_mask
114+
115+
116+
def create_data_loader(dataset, tokenizer, args, mode):
117+
trans_func = partial(
118+
convert_example,
119+
tokenizer=tokenizer,
120+
max_seq_len=args.max_seq_len,
121+
max_target_len=args.max_target_len,
122+
max_title_len=args.max_title_len,
123+
mode=mode)
124+
dataset = dataset.map(trans_func, lazy=True)
125+
if mode == 'train':
126+
batch_sampler = DistributedBatchSampler(
127+
dataset, batch_size=args.batch_size, shuffle=True)
128+
else:
129+
batch_sampler = BatchSampler(
130+
dataset, batch_size=args.batch_size // 2, shuffle=False)
131+
collate_fn = partial(batchify_fn, pad_val=tokenizer.pad_token_id, mode=mode)
132+
data_loader = DataLoader(
133+
dataset,
134+
batch_sampler=batch_sampler,
135+
collate_fn=collate_fn,
136+
return_list=True)
137+
return dataset, data_loader
138+
139+
140+
def post_process_sum(token_ids, tokenizer):
141+
"""Post-process the decoded sequence. Truncate from the first <eos>."""
142+
eos_pos = len(token_ids)
143+
for i, tok_id in enumerate(token_ids):
144+
if tok_id == tokenizer.mask_token_id:
145+
eos_pos = i
146+
break
147+
token_ids = token_ids[:eos_pos]
148+
tokens = tokenizer.convert_ids_to_tokens(token_ids)
149+
tokens = tokenizer.merge_subword(tokens)
150+
special_tokens = ['[UNK]']
151+
tokens = [token for token in tokens if token not in special_tokens]
152+
return token_ids, tokens
153+
154+
155+
def select_sum(ids, scores, tokenizer, max_dec_len=None,
156+
num_return_sequences=1):
157+
ids = ids.numpy()
158+
scores = scores.numpy()
159+
160+
if len(ids) != len(scores) or (len(ids) % num_return_sequences) != 0:
161+
raise ValueError(
162+
"the length of `ids` is {}, but the `num_return_sequences` is {}".
163+
format(len(ids), num_return_sequences))
164+
165+
group = []
166+
tmp = []
167+
for pred, score in zip(ids, scores):
168+
pred_token_ids, pred_tokens = post_process_sum(pred, tokenizer)
169+
num_token = len(pred_token_ids)
170+
171+
target = "".join(pred_tokens)
172+
173+
# not ending
174+
if max_dec_len is not None and num_token >= max_dec_len:
175+
score -= 1e3
176+
177+
tmp.append([target, score])
178+
if len(tmp) == num_return_sequences:
179+
group.append(tmp)
180+
tmp = []
181+
182+
results = []
183+
for preds in group:
184+
preds = sorted(preds, key=lambda x: -x[1])
185+
results.append(preds[0][0])
186+
return results

0 commit comments

Comments
 (0)