Skip to content

Commit 10ac335

Browse files
authored
Merge pull request #1420 from linjieccc/up_wordtag
Update usage of wordtag
2 parents 072b9ff + d3fcc61 commit 10ac335

File tree

9 files changed

+125
-105
lines changed

9 files changed

+125
-105
lines changed

docs/model_zoo/taskflow.md

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@ seg(["第十四届全运会在西安举办", "三亚是一个美丽的城市"])
8383

8484
#### 自定义词典
8585

86-
用户可以通过装载自定义词典来定制化分词结果。
86+
用户可以通过装载自定义词典来定制化分词结果。词典文件每一行表示一个自定义item,可以由一个单词或者多个单词组成。
8787

88-
词典文件`custom_seg.txt`示例:
88+
词典文件`user_dict.txt`示例:
8989

9090
```text
9191
平原上的火焰
@@ -103,14 +103,15 @@ seg(["第十四届全运会在西安举办", "三亚是一个美丽的城市"])
103103
```python
104104
from paddlenlp import Taskflow
105105

106-
my_seg = Taskflow("word_segmentation", custom_vocab="custom_seg.txt")
106+
my_seg = Taskflow("word_segmentation", user_dict="user_dict.txt")
107107
my_seg("平原上的火焰计划于年末上映")
108108
>>> ['平原上的火焰', '计划', '', '', '', '上映']
109109
```
110110

111111
#### 可配置参数说明
112112

113113
* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。
114+
* `user_dict`:用户自定义词典文件,默认为None。
114115

115116
### 词性标注
116117

@@ -139,9 +140,9 @@ tag(["第十四届全运会在西安举办", "三亚是一个美丽的城市"])
139140

140141
#### 自定义词典
141142

142-
用户可以通过装载自定义词典来定制化分词和词性标注结果。
143+
用户可以通过装载自定义词典来定制化分词和词性标注结果。词典文件每一行表示一个自定义item,可以由一个单词或者多个单词组成,单词后面可以添加自定义标签,格式为`item/tag`,如果不添加自定义标签,则使用模型默认标签。
143144

144-
词典文件`custom_pos.txt`示例:
145+
词典文件`user_dict.txt`示例:
145146

146147
```text
147148
赛里木湖/LAKE
@@ -161,14 +162,15 @@ tag(["第十四届全运会在西安举办", "三亚是一个美丽的城市"])
161162
```python
162163
from paddlenlp import Taskflow
163164

164-
my_pos = Taskflow("pos_tagging", custom_vocab="custom_pos.txt")
165+
my_pos = Taskflow("pos_tagging", user_dict="user_dict.txt")
165166
my_pos("赛里木湖是新疆海拔最高的高山湖泊")
166167
>>> [('赛里木湖', 'LAKE'), ('', 'v'), ('新疆', 'LOC'), ('海拔最高', 'n'), ('', 'u'), ('', 'a'), ('', 'n'), ('', 'n'), ('', 'n')]
167168
```
168169

169170
#### 可配置参数说明
170171

171172
* `batch_size`:批处理大小,请结合机器情况进行调整,默认值为1。
173+
* `user_dict`:用户自定义词典文件,默认为None。
172174

173175
### 命名实体识别
174176

@@ -183,9 +185,53 @@ ner(["热梅茶是一道以梅子为主要原料制作的茶饮", "《孤女》
183185
>>> [[('热梅茶', '饮食类_饮品'), ('', '肯定词'), ('一道', '数量词'), ('', '介词'), ('梅子', '饮食类'), ('', '肯定词'), ('主要原料', '物体类'), ('制作', '场景事件'), ('', '助词'), ('茶饮', '饮食类_饮品')], [('', 'w'), ('孤女', '作品类_实体'), ('', 'w'), ('', '肯定词'), ('2010年', '时间类'), ('九州出版社', '组织机构类'), ('出版', '场景事件'), ('', '助词'), ('小说', '作品类_概念'), ('', 'w'), ('作者', '人物类_概念'), ('', '肯定词'), ('余兼羽', '人物类_实体')]]
184186
```
185187

188+
#### 自定义词典
189+
190+
用户可以通过装载自定义词典来定制化分词和词性标注结果。词典文件每一行表示一个自定义item,可以由一个单词或者多个单词组成,单词后面可以添加自定义标签,格式为`item/tag`,如果不添加自定义标签,则使用模型默认标签。
191+
192+
词典文件`user_dict.txt`示例:
193+
194+
```text
195+
长津湖/电影类_实体
196+
收/词汇用语 尾/术语类
197+
最 大
198+
海外票仓
199+
```
200+
201+
以"《长津湖》收尾,北美是最大海外票仓"为例,原本的输出结果为:
202+
203+
```text
204+
[('《', 'w'), ('长津湖', '作品类_实体'), ('》', 'w'), ('收尾', '场景事件'), (',', 'w'), ('北美', '世界地区类'), ('是', '肯定词'), ('最大', '修饰词'), ('海外', '场所类'), ('票仓', '词汇用语')]
205+
```
206+
207+
装载自定义词典及输出结果示例:
208+
209+
```python
210+
from paddlenlp import Taskflow
211+
212+
my_ner = Taskflow("ner", user_dict="user_dict.txt")
213+
my_ner("《长津湖》收尾,北美是最大海外票仓")
214+
>>> [('', 'w'), ('长津湖', '电影类_实体'), ('', 'w'), ('', '词汇用语'), ('', '术语类'), ('', 'w'), ('北美', '世界地区类'), ('', '肯定词'), ('', '修饰词'), ('', '修饰词'), ('海外票仓', '场所类')]
215+
```
216+
217+
#### 自定义NER模型
218+
219+
用户可以使用自己的数据训练自定义NER模型,参考[NER-WordTag增量训练示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm)
220+
221+
使用Taskflow加载自定义模型进行一键预测:
222+
223+
```shell
224+
from paddlenlp import Taskflow
225+
226+
my_ner = Taskflow("ner", params_path="/path/to/your/params", tag_path="/path/to/your/tag")
227+
```
228+
186229
#### 可配置参数说明
187230

188231
* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。
232+
* `user_dict`:用户自定义词典文件,默认为None。
233+
* `params_path`:模型参数文件路径,默认为None。
234+
* `tag_path`:标签文件路径,默认为None。
189235

190236
### 文本纠错
191237

examples/text_to_knowledge/ernie-ctm/README.md

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,6 @@ data/
8181
《/w 全球化与中国:理论与发展趋势/作品类_实体 》/w 是/肯定词 2010年/时间类 经济管理出版社/组织机构类 出版/场景事件 的/助词 图书/作品类_概念 ,/w 作者/人物类_概念 是/肯定词 余永定/人物类_实体 、/w 路爱国/人物类_实体 、/w 高海红/人物类_实体 。/w
8282
```
8383

84-
WordTag模型使用了**BIOES标注体系**,用户可以在标签文件中(该示例为`tags.txt`)按照该标注体系自定义添加词性或命名实体类别,标签文件示例:
85-
86-
```text
87-
B-组织机构类_企事业单位
88-
I-组织机构类_企事业单位
89-
E-组织机构类_企事业单位
90-
S-组织机构类_企事业单位
91-
```
92-
9384
#### 模型训练
9485

9586
```shell

examples/text_to_knowledge/ernie-ctm/predict.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ def do_predict(data,
7878
input_ids = paddle.to_tensor(input_ids)
7979
token_type_ids = paddle.to_tensor(token_type_ids)
8080
seq_len = paddle.to_tensor(seq_len)
81-
logits, _ = model(input_ids, token_type_ids)
82-
_, pred_tags = viterbi_decoder(logits, seq_len)
81+
pred_tags = model(input_ids, token_type_ids, lengths=seq_len)
8382
all_pred_tags.extend(pred_tags.numpy().tolist())
8483
results = decode(data, all_pred_tags, summary_num, idx_to_tags)
8584
return results
@@ -95,14 +94,9 @@ def do_predict(data,
9594
tags_to_idx = load_dict(os.path.join(args.data_dir, "tags.txt"))
9695
idx_to_tags = dict(zip(*(tags_to_idx.values(), tags_to_idx.keys())))
9796

98-
crf = LinearChainCrf(len(tags_to_idx), 100, with_start_stop_tag=False)
99-
viterbi_decoder = ViterbiDecoder(crf.transitions, False)
100-
10197
model = ErnieCtmWordtagModel.from_pretrained(
10298
"wordtag",
103-
num_tag=len(tags_to_idx),
104-
num_cls_label=4,
105-
ignore_index=tags_to_idx["O"])
99+
num_tag=len(tags_to_idx))
106100
tokenizer = ErnieCtmTokenizer.from_pretrained("wordtag")
107101

108102
if args.params_path and os.path.isfile(args.params_path):
@@ -113,7 +107,7 @@ def do_predict(data,
113107
results = do_predict(data,
114108
model,
115109
tokenizer,
116-
viterbi_decoder,
110+
model.viterbi_decoder,
117111
tags_to_idx,
118112
idx_to_tags,
119113
batch_size=args.batch_size)

examples/text_to_knowledge/ernie-ctm/train.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,17 @@ def set_seed(seed):
6464

6565

6666
@paddle.no_grad()
67-
def evaluate(model, metric, criterion, data_loader, tags, tags_to_idx):
67+
def evaluate(model, metric, data_loader, tags, tags_to_idx):
6868
model.eval()
6969
metric.reset()
7070
losses = []
7171
for batch in data_loader():
7272
input_ids, token_type_ids, seq_len, tags = batch
73-
seq_logits, _ = model(input_ids,
73+
loss, seq_logits = model(input_ids,
7474
token_type_ids,
7575
lengths=seq_len,
7676
tag_labels=tags)
77-
loss = criterion(seq_logits, seq_len, tags).mean()
77+
loss = loss.mean()
7878
losses.append(loss.numpy())
7979

8080
correct = metric.compute(
@@ -109,9 +109,9 @@ def do_train(args):
109109
tokenizer = ErnieCtmTokenizer.from_pretrained("wordtag")
110110
model = ErnieCtmWordtagModel.from_pretrained(
111111
"wordtag",
112-
num_tag=len(tags_to_idx),
113-
num_cls_label=4,
114-
ignore_index=tags_to_idx["O"])
112+
num_tag=len(tags_to_idx))
113+
model.crf_loss = LinearChainCrfLoss(
114+
LinearChainCrf(len(tags_to_idx), 0.1, with_start_stop_tag=False))
115115

116116
trans_func = partial(
117117
convert_example,
@@ -170,9 +170,6 @@ def do_train(args):
170170
logger.info("WarmUp steps: %s" % warmup)
171171

172172
metric = SequenceAccuracy()
173-
crf_lr = 0.1
174-
crf = LinearChainCrf(len(tags_to_idx), crf_lr, with_start_stop_tag=False)
175-
criterion = LinearChainCrfLoss(crf)
176173

177174
total_loss = 0
178175
global_step = 0
@@ -185,12 +182,11 @@ def do_train(args):
185182
global_step += 1
186183
input_ids, token_type_ids, seq_len, tags = batch
187184

188-
seq_logits, _ = model(
185+
loss, _ = model(
189186
input_ids,
190187
token_type_ids,
191188
lengths=seq_len,
192189
tag_labels=tags)
193-
loss = criterion(seq_logits, seq_len, tags)
194190
loss = loss.mean()
195191
total_loss += loss
196192
loss.backward()
@@ -219,7 +215,7 @@ def do_train(args):
219215
model_to_save.save_pretrained(output_dir)
220216
tokenizer.save_pretrained(output_dir)
221217

222-
evaluate(model, metric, criterion, dev_data_loader, tags, tags_to_idx)
218+
evaluate(model, metric, dev_data_loader, tags, tags_to_idx)
223219

224220

225221
def print_arguments(args):

paddlenlp/taskflow/knowledge_mining.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@
128128
from paddlenlp import Taskflow
129129
130130
# 默认使用WordTag词类知识标注工具
131-
wordtag = Taskflow("knowledge_mining")
131+
wordtag = Taskflow("knowledge_mining", model="wordtag")
132132
wordtag("《孤女》是2010年九州出版社出版的小说,作者是余兼羽")
133133
'''
134134
[{'text': '《孤女》是2010年九州出版社出版的小说,作者是余兼羽', 'items': [{'item': '《', 'offset': 0, 'wordtag_label': 'w', 'length': 1}, {'item': '孤女', 'offset': 1, 'wordtag_label': '作品类_实体', 'length': 2}, {'item': '》', 'offset': 3, 'wordtag_label': 'w', 'length': 1}, {'item': '是', 'offset': 4, 'wordtag_label': '肯定词', 'length': 1, 'termid': '肯定否定词_cb_是'}, {'item': '2010年', 'offset': 5, 'wordtag_label': '时间类', 'length': 5, 'termid': '时间阶段_cb_2010年'}, {'item': '九州出版社', 'offset': 10, 'wordtag_label': '组织机构类', 'length': 5, 'termid': '组织机构_eb_九州出版社'}, {'item': '出版', 'offset': 15, 'wordtag_label': '场景事件', 'length': 2, 'termid': '场景事件_cb_出版'}, {'item': '的', 'offset': 17, 'wordtag_label': '助词', 'length': 1, 'termid': '助词_cb_的'}, {'item': '小说', 'offset': 18, 'wordtag_label': '作品类_概念', 'length': 2, 'termid': '小说_cb_小说'}, {'item': ',', 'offset': 20, 'wordtag_label': 'w', 'length': 1}, {'item': '作者', 'offset': 21, 'wordtag_label': '人物类_概念', 'length': 2, 'termid': '人物_cb_作者'}, {'item': '是', 'offset': 23, 'wordtag_label': '肯定词', 'length': 1, 'termid': '肯定否定词_cb_是'}, {'item': '余兼羽', 'offset': 24, 'wordtag_label': '人物类_实体', 'length': 3}]}]
@@ -207,8 +207,6 @@ def __init__(self,
207207
self._termtree = TermTree.from_dir(term_schema_path, term_data_path,
208208
self._linking)
209209

210-
self.crf = LinearChainCrf(len(self._tags_to_index), 100, with_start_stop_tag=False)
211-
self._viterbi_decoder = ViterbiDecoder(self.crf.transitions, False)
212210
self._usage = usage
213211
self._summary_num = 2
214212

@@ -510,6 +508,9 @@ def _construct_input_spec(self):
510508
paddle.static.InputSpec(shape=[None, None],
511509
dtype="int64",
512510
name="token_type_ids"), # token_type_ids
511+
paddle.static.InputSpec(shape=[None],
512+
dtype="int64",
513+
name="seq_len"), # seq_len
513514
]
514515

515516
def _construct_model(self, model):
@@ -518,9 +519,7 @@ def _construct_model(self, model):
518519
"""
519520
model_instance = ErnieCtmWordtagModel.from_pretrained(
520521
model,
521-
num_cls_label=4,
522-
num_tag=len(self._tags_to_index),
523-
ignore_index=self._tags_to_index["O"])
522+
num_tag=len(self._tags_to_index))
524523
config_keys = ErnieCtmWordtagModel.pretrained_init_configuration[
525524
self.model]
526525
self.kwargs.update(config_keys)
@@ -554,11 +553,10 @@ def _run_model(self, inputs):
554553
input_ids, token_type_ids, seq_len = batch
555554
self.input_handles[0].copy_from_cpu(input_ids.numpy())
556555
self.input_handles[1].copy_from_cpu(token_type_ids.numpy())
556+
self.input_handles[2].copy_from_cpu(seq_len.numpy())
557557
self.predictor.run()
558-
logits = self.output_handle[0].copy_to_cpu()
559-
score, pred_tags = self._viterbi_decoder(
560-
paddle.to_tensor(logits), seq_len)
561-
all_pred_tags.extend(pred_tags.numpy().tolist())
558+
pred_tags = self.output_handle[0].copy_to_cpu()
559+
all_pred_tags.extend(pred_tags.tolist())
562560
inputs['all_pred_tags'] = all_pred_tags
563561
return inputs
564562

paddlenlp/taskflow/lexical_analysis.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ class LacTask(Task):
9494
def __init__(self, task, model, **kwargs):
9595
super().__init__(task=task, model=model, **kwargs)
9696
self._usage = usage
97-
self._custom_vocab = self.kwargs[
98-
'custom_vocab'] if 'custom_vocab' in self.kwargs else None
97+
self._user_dict = self.kwargs[
98+
'user_dict'] if 'user_dict' in self.kwargs else None
9999
word_dict_path = download_file(
100100
self._task_path, "lac_params" + os.path.sep + "word.dic",
101101
URLS['lac_params'][0], URLS['lac_params'][1])
@@ -113,9 +113,9 @@ def __init__(self, task, model, **kwargs):
113113
self._id2tag_dict = dict(
114114
zip(self._tag_vocab.values(), self._tag_vocab.keys()))
115115
self._get_inference_model()
116-
if self._custom_vocab:
116+
if self._user_dict:
117117
self._custom = Customization()
118-
self._custom.load_customization(self._custom_vocab)
118+
self._custom.load_customization(self._user_dict)
119119
else:
120120
self._custom = None
121121

paddlenlp/taskflow/named_entity_recognition.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .utils import download_file
2525
from .utils import TermTree
2626
from .knowledge_mining import WordTagTask
27+
from .utils import Customization
2728

2829
usage = r"""
2930
from paddlenlp import Taskflow
@@ -35,8 +36,7 @@
3536
'''
3637
3738
ner = Taskflow("ner")
38-
ner(["热梅茶是一道以梅子为主要原料制作的茶饮",
39-
"《孤女》是2010年九州出版社出版的小说,作者是余兼羽"])
39+
ner(["热梅茶是一道以梅子为主要原料制作的茶饮", "《孤女》是2010年九州出版社出版的小说,作者是余兼羽"])
4040
'''
4141
[[('热梅茶', '饮食类_饮品'), ('是', '肯定词'), ('一道', '数量词'), ('以', '介词'), ('梅子', '饮食类'), ('为', '肯定词'), ('主要原料', '物体类'), ('制作', '场景事件'), ('的', '助词'), ('茶饮', '饮食类_饮品')], [('《', 'w'), ('孤女', '作品类_实体'), ('》', 'w'), ('是', '肯定词'), ('2010年', '时间类'), ('九州出版社', '组织机构类'), ('出版', '场景事件'), ('的', '助词'), ('小说', '作品类_概念'), (',', 'w'), ('作者', '人物类_概念'), ('是', '肯定词'), ('余兼羽', '人物类_实体')]]
4242
'''
@@ -56,6 +56,13 @@ class NERTask(WordTagTask):
5656

5757
def __init__(self, model, task, **kwargs):
5858
super().__init__(model=model, task=task, **kwargs)
59+
self._user_dict = self.kwargs[
60+
'user_dict'] if 'user_dict' in self.kwargs else None
61+
if self._user_dict:
62+
self._custom = Customization()
63+
self._custom.load_customization(self._user_dict)
64+
else:
65+
self._custom = None
5966

6067
def _decode(self, batch_texts, batch_pred_tags):
6168
batch_results = []
@@ -65,7 +72,8 @@ def _decode(self, batch_texts, batch_pred_tags):
6572
for index in batch_pred_tags[sent_index][self.summary_num:-1]
6673
]
6774
sent = batch_texts[sent_index]
68-
75+
if self._custom:
76+
self._custom.parse_customization(sent, tags, prefix=True)
6977
sent_out = []
7078
tags_out = []
7179
partial_word = ""

0 commit comments

Comments
 (0)