Skip to content

Commit 18f3a7c

Browse files
committed
Update usage of wordtag
1 parent 0287e54 commit 18f3a7c

File tree

8 files changed

+129
-95
lines changed

8 files changed

+129
-95
lines changed

docs/model_zoo/taskflow.md

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ my_seg("平原上的火焰计划于年末上映")
111111
#### 可配置参数说明
112112

113113
* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。
114+
* `custom_vocab`:用户自定义词典文件,默认为None。
114115

115116
### 词性标注
116117

@@ -169,6 +170,7 @@ my_pos("赛里木湖是新疆海拔最高的高山湖泊")
169170
#### 可配置参数说明
170171

171172
* `batch_size`:批处理大小,请结合机器情况进行调整,默认值为1。
173+
* `custom_vocab`:用户自定义词典文件,默认为None。
172174

173175
### 命名实体识别
174176

@@ -183,9 +185,67 @@ ner(["热梅茶是一道以梅子为主要原料制作的茶饮", "《孤女》
183185
>>> [[('热梅茶', '饮食类_饮品'), ('', '肯定词'), ('一道', '数量词'), ('', '介词'), ('梅子', '饮食类'), ('', '肯定词'), ('主要原料', '物体类'), ('制作', '场景事件'), ('', '助词'), ('茶饮', '饮食类_饮品')], [('', 'w'), ('孤女', '作品类_实体'), ('', 'w'), ('', '肯定词'), ('2010年', '时间类'), ('九州出版社', '组织机构类'), ('出版', '场景事件'), ('', '助词'), ('小说', '作品类_概念'), ('', 'w'), ('作者', '人物类_概念'), ('', '肯定词'), ('余兼羽', '人物类_实体')]]
184186
```
185187

188+
- 标签集合:
189+
190+
|人物类_实体|物体类|生物类_动物|医学术语类|链接地址|肯定词|
191+
|人物类_概念|物体类_兵器|品牌名|术语类_生物体|个性特征|否定词|
192+
|作品类_实体|物体类_化学物质|场所类|疾病损伤类|感官特征|数量词|
193+
|作品类_概念|其他角色类|场所类_交通场所|疾病损伤类_植物病虫害|场景事件|叹词|
194+
|组织机构类|文化类|位置方位|宇宙类|介词|拟声词|
195+
|组织机构类_企事业单位|文化类_语言文字|世界地区类|事件类|介词_方位介词|修饰词|
196+
|组织机构类_医疗卫生机构|文化类_奖项赛事活动|饮食类|时间类|助词|外语单词|
197+
|组织机构类_国家机关|文化类_制度政策协议|饮食类_菜品|时间类_特殊日|代词|英语单词|
198+
|组织机构类_体育组织机构|文化类_姓氏与人名|饮食类_饮品|术语类|连词|汉语拼音|
199+
|组织机构类_教育组织机构|生物类|药物类|术语类_符号指标类|副词|词汇用语|
200+
|组织机构类_军事组织机构|生物类_植物|药物类_中药|信息资料|疑问词|w(标点)|
201+
202+
#### 自定义词典
203+
204+
用户可以通过装载自定义词典来定制化分词和词性标注结果。
205+
206+
词典文件`custom_ner.txt`示例:
207+
208+
```text
209+
长津湖/电影类_实体
210+
收/词汇用语 尾/术语类
211+
最 大
212+
海外票仓
213+
```
214+
215+
以"《长津湖》收尾,北美是最大海外票仓"为例,原本的输出结果为:
216+
217+
```text
218+
[('《', 'w'), ('长津湖', '作品类_实体'), ('》', 'w'), ('收尾', '场景事件'), (',', 'w'), ('北美', '世界地区类'), ('是', '肯定词'), ('最大', '修饰词'), ('海外', '场所类'), ('票仓', '词汇用语')]
219+
```
220+
221+
装载自定义词典及输出结果示例:
222+
223+
```python
224+
from paddlenlp import Taskflow
225+
226+
my_ner = Taskflow("ner", custom_vocab="custom_ner.txt")
227+
my_ner("《长津湖》收尾,北美是最大海外票仓")
228+
>>> [('', 'w'), ('长津湖', '电影类_实体'), ('', 'w'), ('', '词汇用语'), ('', '术语类'), ('', 'w'), ('北美', '世界地区类'), ('', '肯定词'), ('', '修饰词'), ('', '修饰词'), ('海外票仓', '场所类')]
229+
```
230+
231+
#### 自定义NER模型
232+
233+
用户可以使用自己的数据训练自定义NER模型,参考[NER-WordTag增量训练示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm)
234+
235+
使用Taskflow加载自定义模型进行一键预测:
236+
237+
```shell
238+
from paddlenlp import Taskflow
239+
240+
my_ner = Taskflow("ner", params_path="/path/to/your/params", tag_path="/path/to/your/tag")
241+
```
242+
186243
#### 可配置参数说明
187244

188245
* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。
246+
* `custom_vocab`:用户自定义词典文件,默认为None。
247+
* `params_path`:模型参数文件路径,默认为None。
248+
* `tag_path`:标签文件路径,默认为None。
189249

190250
### 文本纠错
191251

examples/text_to_knowledge/ernie-ctm/README.md

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,6 @@ data/
8181
《/w 全球化与中国:理论与发展趋势/作品类_实体 》/w 是/肯定词 2010年/时间类 经济管理出版社/组织机构类 出版/场景事件 的/助词 图书/作品类_概念 ,/w 作者/人物类_概念 是/肯定词 余永定/人物类_实体 、/w 路爱国/人物类_实体 、/w 高海红/人物类_实体 。/w
8282
```
8383

84-
WordTag模型使用了**BIOES标注体系**,用户可以在标签文件中(该示例为`tags.txt`)按照该标注体系自定义添加词性或命名实体类别,标签文件示例:
85-
86-
```text
87-
B-组织机构类_企事业单位
88-
I-组织机构类_企事业单位
89-
E-组织机构类_企事业单位
90-
S-组织机构类_企事业单位
91-
```
92-
9384
#### 模型训练
9485

9586
```shell

examples/text_to_knowledge/ernie-ctm/predict.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ def do_predict(data,
7878
input_ids = paddle.to_tensor(input_ids)
7979
token_type_ids = paddle.to_tensor(token_type_ids)
8080
seq_len = paddle.to_tensor(seq_len)
81-
logits, _ = model(input_ids, token_type_ids)
82-
_, pred_tags = viterbi_decoder(logits, seq_len)
81+
pred_tags = model(input_ids, token_type_ids, lengths=seq_len)
8382
all_pred_tags.extend(pred_tags.numpy().tolist())
8483
results = decode(data, all_pred_tags, summary_num, idx_to_tags)
8584
return results
@@ -95,14 +94,9 @@ def do_predict(data,
9594
tags_to_idx = load_dict(os.path.join(args.data_dir, "tags.txt"))
9695
idx_to_tags = dict(zip(*(tags_to_idx.values(), tags_to_idx.keys())))
9796

98-
crf = LinearChainCrf(len(tags_to_idx), 100, with_start_stop_tag=False)
99-
viterbi_decoder = ViterbiDecoder(crf.transitions, False)
100-
10197
model = ErnieCtmWordtagModel.from_pretrained(
10298
"wordtag",
103-
num_tag=len(tags_to_idx),
104-
num_cls_label=4,
105-
ignore_index=tags_to_idx["O"])
99+
num_tag=len(tags_to_idx))
106100
tokenizer = ErnieCtmTokenizer.from_pretrained("wordtag")
107101

108102
if args.params_path and os.path.isfile(args.params_path):
@@ -113,7 +107,7 @@ def do_predict(data,
113107
results = do_predict(data,
114108
model,
115109
tokenizer,
116-
viterbi_decoder,
110+
model.viterbi_decoder,
117111
tags_to_idx,
118112
idx_to_tags,
119113
batch_size=args.batch_size)

examples/text_to_knowledge/ernie-ctm/train.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,17 @@ def set_seed(seed):
6464

6565

6666
@paddle.no_grad()
67-
def evaluate(model, metric, criterion, data_loader, tags, tags_to_idx):
67+
def evaluate(model, metric, data_loader, tags, tags_to_idx):
6868
model.eval()
6969
metric.reset()
7070
losses = []
7171
for batch in data_loader():
7272
input_ids, token_type_ids, seq_len, tags = batch
73-
seq_logits, _ = model(input_ids,
73+
loss, seq_logits = model(input_ids,
7474
token_type_ids,
7575
lengths=seq_len,
7676
tag_labels=tags)
77-
loss = criterion(seq_logits, seq_len, tags).mean()
77+
loss = loss.mean()
7878
losses.append(loss.numpy())
7979

8080
correct = metric.compute(
@@ -109,9 +109,9 @@ def do_train(args):
109109
tokenizer = ErnieCtmTokenizer.from_pretrained("wordtag")
110110
model = ErnieCtmWordtagModel.from_pretrained(
111111
"wordtag",
112-
num_tag=len(tags_to_idx),
113-
num_cls_label=4,
114-
ignore_index=tags_to_idx["O"])
112+
num_tag=len(tags_to_idx))
113+
model.crf_loss = LinearChainCrfLoss(
114+
LinearChainCrf(len(tags_to_idx), 0.1, with_start_stop_tag=False))
115115

116116
trans_func = partial(
117117
convert_example,
@@ -170,9 +170,6 @@ def do_train(args):
170170
logger.info("WarmUp steps: %s" % warmup)
171171

172172
metric = SequenceAccuracy()
173-
crf_lr = 0.1
174-
crf = LinearChainCrf(len(tags_to_idx), crf_lr, with_start_stop_tag=False)
175-
criterion = LinearChainCrfLoss(crf)
176173

177174
total_loss = 0
178175
global_step = 0
@@ -185,12 +182,11 @@ def do_train(args):
185182
global_step += 1
186183
input_ids, token_type_ids, seq_len, tags = batch
187184

188-
seq_logits, _ = model(
185+
loss, _ = model(
189186
input_ids,
190187
token_type_ids,
191188
lengths=seq_len,
192189
tag_labels=tags)
193-
loss = criterion(seq_logits, seq_len, tags)
194190
loss = loss.mean()
195191
total_loss += loss
196192
loss.backward()
@@ -219,7 +215,7 @@ def do_train(args):
219215
model_to_save.save_pretrained(output_dir)
220216
tokenizer.save_pretrained(output_dir)
221217

222-
evaluate(model, metric, criterion, dev_data_loader, tags, tags_to_idx)
218+
evaluate(model, metric, dev_data_loader, tags, tags_to_idx)
223219

224220

225221
def print_arguments(args):

paddlenlp/taskflow/knowledge_mining.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@
128128
from paddlenlp import Taskflow
129129
130130
# 默认使用WordTag词类知识标注工具
131-
wordtag = Taskflow("knowledge_mining")
131+
wordtag = Taskflow("knowledge_mining", model="wordtag")
132132
wordtag("《孤女》是2010年九州出版社出版的小说,作者是余兼羽")
133133
'''
134134
[{'text': '《孤女》是2010年九州出版社出版的小说,作者是余兼羽', 'items': [{'item': '《', 'offset': 0, 'wordtag_label': 'w', 'length': 1}, {'item': '孤女', 'offset': 1, 'wordtag_label': '作品类_实体', 'length': 2}, {'item': '》', 'offset': 3, 'wordtag_label': 'w', 'length': 1}, {'item': '是', 'offset': 4, 'wordtag_label': '肯定词', 'length': 1, 'termid': '肯定否定词_cb_是'}, {'item': '2010年', 'offset': 5, 'wordtag_label': '时间类', 'length': 5, 'termid': '时间阶段_cb_2010年'}, {'item': '九州出版社', 'offset': 10, 'wordtag_label': '组织机构类', 'length': 5, 'termid': '组织机构_eb_九州出版社'}, {'item': '出版', 'offset': 15, 'wordtag_label': '场景事件', 'length': 2, 'termid': '场景事件_cb_出版'}, {'item': '的', 'offset': 17, 'wordtag_label': '助词', 'length': 1, 'termid': '助词_cb_的'}, {'item': '小说', 'offset': 18, 'wordtag_label': '作品类_概念', 'length': 2, 'termid': '小说_cb_小说'}, {'item': ',', 'offset': 20, 'wordtag_label': 'w', 'length': 1}, {'item': '作者', 'offset': 21, 'wordtag_label': '人物类_概念', 'length': 2, 'termid': '人物_cb_作者'}, {'item': '是', 'offset': 23, 'wordtag_label': '肯定词', 'length': 1, 'termid': '肯定否定词_cb_是'}, {'item': '余兼羽', 'offset': 24, 'wordtag_label': '人物类_实体', 'length': 3}]}]
@@ -207,8 +207,6 @@ def __init__(self,
207207
self._termtree = TermTree.from_dir(term_schema_path, term_data_path,
208208
self._linking)
209209

210-
self.crf = LinearChainCrf(len(self._tags_to_index), 100, with_start_stop_tag=False)
211-
self._viterbi_decoder = ViterbiDecoder(self.crf.transitions, False)
212210
self._usage = usage
213211
self._summary_num = 2
214212

@@ -510,6 +508,9 @@ def _construct_input_spec(self):
510508
paddle.static.InputSpec(shape=[None, None],
511509
dtype="int64",
512510
name="token_type_ids"), # token_type_ids
511+
paddle.static.InputSpec(shape=[None],
512+
dtype="int64",
513+
name="seq_len"), # seq_len
513514
]
514515

515516
def _construct_model(self, model):
@@ -518,9 +519,7 @@ def _construct_model(self, model):
518519
"""
519520
model_instance = ErnieCtmWordtagModel.from_pretrained(
520521
model,
521-
num_cls_label=4,
522-
num_tag=len(self._tags_to_index),
523-
ignore_index=self._tags_to_index["O"])
522+
num_tag=len(self._tags_to_index))
524523
config_keys = ErnieCtmWordtagModel.pretrained_init_configuration[
525524
self.model]
526525
self.kwargs.update(config_keys)
@@ -554,11 +553,10 @@ def _run_model(self, inputs):
554553
input_ids, token_type_ids, seq_len = batch
555554
self.input_handles[0].copy_from_cpu(input_ids.numpy())
556555
self.input_handles[1].copy_from_cpu(token_type_ids.numpy())
556+
self.input_handles[2].copy_from_cpu(seq_len.numpy())
557557
self.predictor.run()
558-
logits = self.output_handle[0].copy_to_cpu()
559-
score, pred_tags = self._viterbi_decoder(
560-
paddle.to_tensor(logits), seq_len)
561-
all_pred_tags.extend(pred_tags.numpy().tolist())
558+
pred_tags = self.output_handle[0].copy_to_cpu()
559+
all_pred_tags.extend(pred_tags.tolist())
562560
inputs['all_pred_tags'] = all_pred_tags
563561
return inputs
564562

paddlenlp/taskflow/named_entity_recognition.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .utils import download_file
2525
from .utils import TermTree
2626
from .knowledge_mining import WordTagTask
27+
from .utils import Customization
2728

2829
usage = r"""
2930
from paddlenlp import Taskflow
@@ -35,8 +36,7 @@
3536
'''
3637
3738
ner = Taskflow("ner")
38-
ner(["热梅茶是一道以梅子为主要原料制作的茶饮",
39-
"《孤女》是2010年九州出版社出版的小说,作者是余兼羽"])
39+
ner(["热梅茶是一道以梅子为主要原料制作的茶饮", "《孤女》是2010年九州出版社出版的小说,作者是余兼羽"])
4040
'''
4141
[[('热梅茶', '饮食类_饮品'), ('是', '肯定词'), ('一道', '数量词'), ('以', '介词'), ('梅子', '饮食类'), ('为', '肯定词'), ('主要原料', '物体类'), ('制作', '场景事件'), ('的', '助词'), ('茶饮', '饮食类_饮品')], [('《', 'w'), ('孤女', '作品类_实体'), ('》', 'w'), ('是', '肯定词'), ('2010年', '时间类'), ('九州出版社', '组织机构类'), ('出版', '场景事件'), ('的', '助词'), ('小说', '作品类_概念'), (',', 'w'), ('作者', '人物类_概念'), ('是', '肯定词'), ('余兼羽', '人物类_实体')]]
4242
'''
@@ -56,6 +56,13 @@ class NERTask(WordTagTask):
5656

5757
def __init__(self, model, task, **kwargs):
5858
super().__init__(model=model, task=task, **kwargs)
59+
self._custom_vocab = self.kwargs[
60+
'custom_vocab'] if 'custom_vocab' in self.kwargs else None
61+
if self._custom_vocab:
62+
self._custom = Customization()
63+
self._custom.load_customization(self._custom_vocab)
64+
else:
65+
self._custom = None
5966

6067
def _decode(self, batch_texts, batch_pred_tags):
6168
batch_results = []
@@ -65,7 +72,8 @@ def _decode(self, batch_texts, batch_pred_tags):
6572
for index in batch_pred_tags[sent_index][self.summary_num:-1]
6673
]
6774
sent = batch_texts[sent_index]
68-
75+
if self._custom:
76+
self._custom.parse_customization(sent, tags, prefix=True)
6977
sent_out = []
7078
tags_out = []
7179
partial_word = ""

paddlenlp/taskflow/utils.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ def load_customization(self, filename, sep=None):
694694
self.dictitem[phrase] = (tags, offset)
695695
self.ac.add_word(phrase)
696696

697-
def parse_customization(self, query, lac_tags):
697+
def parse_customization(self, query, lac_tags, prefix=False):
698698
"""Use custom vocab to modify the lac results"""
699699
if not self.ac:
700700
logging.warning("customization dict is not load")
@@ -706,16 +706,30 @@ def parse_customization(self, query, lac_tags):
706706
index = begin
707707

708708
tags, offsets = self.dictitem[phrase]
709-
for tag, offset in zip(tags, offsets):
710-
while index < begin + offset:
711-
if len(tag) == 0:
712-
lac_tags[index] = lac_tags[index][:-1] + 'I'
713-
else:
714-
lac_tags[index] = tag + "-I"
715-
index += 1
716-
717-
lac_tags[begin] = lac_tags[begin][:-1] + 'B'
718-
for offset in offsets:
719-
index = begin + offset
720-
if index < len(lac_tags):
721-
lac_tags[index] = lac_tags[index][:-1] + 'B'
709+
710+
if prefix:
711+
for tag, offset in zip(tags, offsets):
712+
while index < begin + offset:
713+
if len(tag) == 0:
714+
lac_tags[index] = "I" + lac_tags[index][1:]
715+
else:
716+
lac_tags[index] = "I-" + tag
717+
index += 1
718+
lac_tags[begin] = "B" + lac_tags[begin][1:]
719+
for offset in offsets:
720+
index = begin + offset
721+
if index < len(lac_tags):
722+
lac_tags[index] = "B" + lac_tags[index][1:]
723+
else:
724+
for tag, offset in zip(tags, offsets):
725+
while index < begin + offset:
726+
if len(tag) == 0:
727+
lac_tags[index] = lac_tags[index][:-1] + "I"
728+
else:
729+
lac_tags[index] = tag + "-I"
730+
index += 1
731+
lac_tags[begin] = lac_tags[begin][:-1] + "B"
732+
for offset in offsets:
733+
index = begin + offset
734+
if index < len(lac_tags):
735+
lac_tags[index] = lac_tags[index][:-1] + "B"

0 commit comments

Comments
 (0)