
Commit 519f552

Merge pull request #1958 from linjieccc/op_path
Optimize task path for Taskflow
2 parents (e99e41b + ca92e63), commit 519f552

6 files changed (+129, -99 lines)


docs/model_zoo/taskflow.md

Lines changed: 13 additions & 13 deletions

@@ -438,10 +438,10 @@ from paddlenlp import Taskflow
 >>> from paddlenlp import Taskflow
 >>> nptag = Taskflow("knowledge_mining", model="nptag")
 >>> nptag("糖醋排骨")
->>> [{'text': '糖醋排骨', 'label': '菜品'}]
+[{'text': '糖醋排骨', 'label': '菜品'}]
 
-nptag(["糖醋排骨", "红曲霉菌"])
->>> [{'text': '糖醋排骨', 'label': '菜品'}, {'text': '红曲霉菌', 'label': '微生物'}]
+>>> nptag(["糖醋排骨", "红曲霉菌"])
+[{'text': '糖醋排骨', 'label': '菜品'}, {'text': '红曲霉菌', 'label': '微生物'}]
 
 # 使用`linking`输出粗粒度类别标签`category`,即WordTag的词汇标签。
 >>> nptag = Taskflow("knowledge_mining", model="nptag", linking=True)
@@ -471,8 +471,7 @@ nptag(["糖醋排骨", "红曲霉菌"])
 [{'source': '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇。', 'target': '遇到逆境时,我们必须勇于面对,而且要愈挫愈勇。', 'errors': [{'position': 3, 'correction': {'竟': '境'}}]}]
 
 # 批量预测
->>> corrector(['遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇。',
-'人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。'])
+>>> corrector(['遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇。', '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。'])
 [{'source': '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇。', 'target': '遇到逆境时,我们必须勇于面对,而且要愈挫愈勇。', 'errors': [{'position': 3, 'correction': {'竟': '境'}}]}, {'source': '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。', 'target': '人生就是如此,经过磨练才能让自己更加茁壮,才能使自己更加乐观。', 'errors': [{'position': 18, 'correction': {'拙': '茁'}}]}]
 ```
 
@@ -628,16 +627,17 @@ nptag(["糖醋排骨", "红曲霉菌"])
 
 | 任务名称 | 默认路径 |  |
 | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
-| `Taskflow("word_segmentation", mode="base")` | `$HOME/.paddlenlp/taskflow/word_segmentation/lac` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/lexical_analysis) |
-| `Taskflow("word_segmentation", mode="accurate")` | `$HOME/.paddlenlp/taskflow/word_segmentation/wordtag` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
-| `Taskflow("ner", mode="fast")` | `$HOME/.paddlenlp/taskflow/ner/lac` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/lexical_analysis) |
-| `Taskflow("ner", mode="accurate")` | `$HOME/.paddlenlp/taskflow/ner/wordtag` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
-| `Taskflow("text_correction", model="csc-ernie-1.0")` | `$HOME/.paddlenlp/taskflow/text_correction/csc-ernie-1.0` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_correction/ernie-csc) |
+| `Taskflow("word_segmentation", mode="base")` | `$HOME/.paddlenlp/taskflow/lac` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/lexical_analysis) |
+| `Taskflow("word_segmentation", mode="accurate")` | `$HOME/.paddlenlp/taskflow/wordtag` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
+| `Taskflow("pos_tagging")` | `$HOME/.paddlenlp/taskflow/lac` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/lexical_analysis) |
+| `Taskflow("ner", mode="fast")` | `$HOME/.paddlenlp/taskflow/lac` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/lexical_analysis) |
+| `Taskflow("ner", mode="accurate")` | `$HOME/.paddlenlp/taskflow/wordtag` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
+| `Taskflow("text_correction", model="ernie-csc")` | `$HOME/.paddlenlp/taskflow/text_correction/ernie-csc` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_correction/ernie-csc) |
 | `Taskflow("dependency_parsing", model="ddparser")` | `$HOME/.paddlenlp/taskflow/dependency_parsing/ddparser` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/dependency_parsing/ddparser) |
 | `Taskflow("dependency_parsing", model="ddparser-ernie-1.0")` | `$HOME/.paddlenlp/taskflow/dependency_parsing/ddparser-ernie-1.0` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/dependency_parsing/ddparser) |
 | `Taskflow("dependency_parsing", model="ddparser-ernie-gram-zh")` | `$HOME/.paddlenlp/taskflow/dependency_parsing/ddparser-ernie-gram-zh` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/dependency_parsing/ddparser) |
 | `Taskflow("sentiment_analysis", model="skep_ernie_1.0_large_ch")` | `$HOME/.paddlenlp/taskflow/sentiment_analysis/skep_ernie_1.0_large_ch` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/sentiment_analysis/skep) |
-| `Taskflow("knowledge_mining", model="wordtag")` | `$HOME/.paddlenlp/taskflow/knowledge_mining/wordtag` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
+| `Taskflow("knowledge_mining", model="wordtag")` | `$HOME/.paddlenlp/taskflow/wordtag` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
 | `Taskflow("knowledge_mining", model="nptag")` | `$HOME/.paddlenlp/taskflow/knowledge_mining/nptag` | [示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/nptag) |
 
 </div></details>
@@ -647,10 +647,10 @@ nptag(["糖醋排骨", "红曲霉菌"])
 
 这里我们以命名实体识别`Taskflow("ner", mode="accurate")`为例,展示如何定制自己的模型。
 
-调用`Taskflow`接口后,程序自动将相关文件下载到`$HOME/.paddlenlp/taskflow/ner/wordtag/`,该默认路径包含以下文件:
+调用`Taskflow`接口后,程序自动将相关文件下载到`$HOME/.paddlenlp/taskflow/wordtag/`,该默认路径包含以下文件:
 
 ```text
-$HOME/.paddlenlp/taskflow/ner/wordtag/
+$HOME/.paddlenlp/taskflow/wordtag/
 ├── model_state.pdparams # 默认模型参数文件
 ├── model_config.json # 默认模型配置文件
 └── tags.txt # 默认标签文件
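A minimal usage sketch of the customization flow described above; the `./custom_wordtag/` directory is a hypothetical example and is assumed to mirror the default layout (model_state.pdparams, model_config.json, tags.txt):

```python
from paddlenlp import Taskflow

# Default: resources are downloaded to $HOME/.paddlenlp/taskflow/wordtag/
ner = Taskflow("ner", mode="accurate")

# Custom model: task_path (hypothetical directory) overrides the default path,
# so the folder must mirror the layout listed above.
my_ner = Taskflow("ner", mode="accurate", task_path="./custom_wordtag/")
```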

paddlenlp/taskflow/dependency_parsing.py

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ class DDParserTask(Task):
 
     resource_files_names = {
         "model_state": "model_state.pdparams",
-        "word_vocab": "vocab.json",
+        "word_vocab": "word_vocab.json",
         "rel_vocab": "rel_vocab.json",
     }
     resource_files_urls = {
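A hedged usage sketch of the parser whose resource map changes above; the input sentence is arbitrary, and the first call is expected to download the files listed in `resource_files_names`, now including `word_vocab.json`, into the task directory:

```python
from paddlenlp import Taskflow

# First use fetches model_state.pdparams, word_vocab.json and rel_vocab.json;
# later calls reuse the cached copies under the ddparser task directory.
ddp = Taskflow("dependency_parsing", model="ddparser")
print(ddp("今天天气很好"))
```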

paddlenlp/taskflow/task.py

Lines changed: 5 additions & 1 deletion

@@ -33,9 +33,10 @@ class Task(metaclass=abc.ABCMeta):
         kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
     """
 
-    def __init__(self, model, task, **kwargs):
+    def __init__(self, model, task, priority_path=None, **kwargs):
         self.model = model
         self.task = task
+        self.priority_path = priority_path
         self.kwargs = kwargs
         self._usage = ""
         # The dygraph model instantce
@@ -50,6 +51,9 @@ def __init__(self, model, task, **kwargs):
             'task_flag'] if 'task_flag' in self.kwargs else self.model
         if 'task_path' in self.kwargs:
             self._task_path = self.kwargs['task_path']
+        elif self.priority_path:
+            self._task_path = os.path.join(self._home_path, "taskflow",
+                                           self.priority_path)
         else:
             self._task_path = os.path.join(self._home_path, "taskflow",
                                            self.task, self.model)
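To make the new precedence explicit, here is a standalone restatement of the branch added above; `resolve_task_path` is a hypothetical helper written for illustration, not a PaddleNLP API:

```python
import os


def resolve_task_path(home_path, task, model, priority_path=None, task_path=None):
    """Mirror the added branch: an explicit task_path wins, then a registered
    priority_path, then the per-task, per-model default directory."""
    if task_path is not None:
        return task_path
    if priority_path:
        return os.path.join(home_path, "taskflow", priority_path)
    return os.path.join(home_path, "taskflow", task, model)


home = os.path.expanduser("~/.paddlenlp")
# e.g. /home/<user>/.paddlenlp/taskflow/wordtag
print(resolve_task_path(home, "ner", "accurate", priority_path="wordtag"))
```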

paddlenlp/taskflow/taskflow.py

Lines changed: 94 additions & 75 deletions

@@ -39,33 +39,80 @@
 warnings.simplefilter(action='ignore', category=Warning, lineno=0, append=False)
 
 TASKS = {
+    'dependency_parsing': {
+        "models": {
+            "ddparser": {
+                "task_class": DDParserTask,
+                "task_flag": 'dependency_parsing-biaffine',
+            },
+            "ddparser-ernie-1.0": {
+                "task_class": DDParserTask,
+                "task_flag": 'dependency_parsing-ernie-1.0',
+            },
+            "ddparser-ernie-gram-zh": {
+                "task_class": DDParserTask,
+                "task_flag": 'dependency_parsing-ernie-gram-zh',
+            },
+        },
+        "default": {
+            "model": "ddparser",
+        }
+    },
+    'dialogue': {
+        "models": {
+            "plato-mini": {
+                "task_class": DialogueTask,
+                "task_flag": "dialogue-plato-mini"
+            },
+        },
+        "default": {
+            "model": "plato-mini",
+        }
+    },
     "knowledge_mining": {
         "models": {
             "wordtag": {
                 "task_class": WordTagTask,
                 "task_flag": 'knowledge_mining-wordtag',
+                "task_priority_path": "wordtag",
             },
             "nptag": {
                 "task_class": NPTagTask,
                 "task_flag": 'knowledge_mining-nptag',
             },
         },
         "default": {
-            "model": "wordtag"
+            "model": "wordtag",
+        }
+    },
+    "lexical_analysis": {
+        "models": {
+            "lac": {
+                "task_class": LacTask,
+                "hidden_size": 128,
+                "emb_dim": 128,
+                "task_flag": 'lexical_analysis-gru_crf',
+                "task_priority_path": "lac",
+            }
+        },
+        "default": {
+            "model": "lac"
         }
     },
     "ner": {
         "modes": {
             "accurate": {
                 "task_class": NERWordTagTask,
                 "task_flag": "ner-wordtag",
+                "task_priority_path": "wordtag",
                 "linking": False,
             },
             "fast": {
                 "task_class": NERLACTask,
                 "hidden_size": 128,
                 "emb_dim": 128,
                 "task_flag": "ner-lac",
+                "task_priority_path": "lac",
             }
         },
         "default": {
@@ -77,69 +124,37 @@
             "gpt-cpm-large-cn": {
                 "task_class": PoetryGenerationTask,
                 "task_flag": 'poetry_generation-gpt-cpm-large-cn',
+                "task_priority_path": "gpt-cpm-large-cn",
             },
         },
         "default": {
             "model": "gpt-cpm-large-cn",
         }
     },
-    "question_answering": {
-        "models": {
-            "gpt-cpm-large-cn": {
-                "task_class": QuestionAnsweringTask,
-                "task_flag": 'question_answering-gpt-cpm-large-cn',
-            },
-        },
-        "default": {
-            "model": "gpt-cpm-large-cn",
-        }
-    },
-    "lexical_analysis": {
+    "pos_tagging": {
         "models": {
             "lac": {
-                "task_class": LacTask,
+                "task_class": POSTaggingTask,
                 "hidden_size": 128,
                 "emb_dim": 128,
-                "task_flag": 'lexical_analysis-gru_crf',
+                "task_flag": 'pos_tagging-gru_crf',
+                "task_priority_path": "lac",
             }
         },
         "default": {
             "model": "lac"
         }
     },
-    "word_segmentation": {
-        "modes": {
-            "fast": {
-                "task_class": SegJiebaTask,
-                "task_flag": "word_segmentation-jieba",
-            },
-            "base": {
-                "task_class": SegLACTask,
-                "hidden_size": 128,
-                "emb_dim": 128,
-                "task_flag": "word_segmentation-gru_crf",
-            },
-            "accurate": {
-                "task_class": SegWordTagTask,
-                "task_flag": "word_segmentation-wordtag",
-                "linking": False,
-            },
-        },
-        "default": {
-            "mode": "base"
-        }
-    },
-    "pos_tagging": {
+    "question_answering": {
         "models": {
-            "lac": {
-                "task_class": POSTaggingTask,
-                "hidden_size": 128,
-                "emb_dim": 128,
-                "task_flag": 'pos_tagging-gru_crf',
-            }
+            "gpt-cpm-large-cn": {
+                "task_class": QuestionAnsweringTask,
+                "task_flag": 'question_answering-gpt-cpm-large-cn',
+                "task_priority_path": "gpt-cpm-large-cn",
+            },
         },
         "default": {
-            "model": "lac"
+            "model": "gpt-cpm-large-cn",
         }
     },
     'sentiment_analysis': {
@@ -157,34 +172,15 @@
             "model": "bilstm"
         }
     },
-    'dependency_parsing': {
-        "models": {
-            "ddparser": {
-                "task_class": DDParserTask,
-                "task_flag": 'dependency_parsing-biaffine',
-            },
-            "ddparser-ernie-1.0": {
-                "task_class": DDParserTask,
-                "task_flag": 'dependency_parsing-ernie-1.0',
-            },
-            "ddparser-ernie-gram-zh": {
-                "task_class": DDParserTask,
-                "task_flag": 'dependency_parsing-ernie-gram-zh',
-            },
-        },
-        "default": {
-            "model": "ddparser"
-        }
-    },
     'text_correction': {
         "models": {
-            "csc-ernie-1.0": {
+            "ernie-csc": {
                 "task_class": CSCTask,
-                "task_flag": "text_correction-csc-ernie-1.0"
+                "task_flag": "text_correction-ernie-csc"
             },
         },
         "default": {
-            "model": "csc-ernie-1.0"
+            "model": "ernie-csc"
         }
     },
     'text_similarity': {
@@ -198,15 +194,28 @@
             "model": "simbert-base-chinese"
         }
     },
-    'dialogue': {
-        "models": {
-            "plato-mini": {
-                "task_class": DialogueTask,
-                "task_flag": "dialogue-plato-mini"
+    "word_segmentation": {
+        "modes": {
+            "fast": {
+                "task_class": SegJiebaTask,
+                "task_flag": "word_segmentation-jieba",
+            },
+            "base": {
+                "task_class": SegLACTask,
+                "hidden_size": 128,
+                "emb_dim": 128,
+                "task_flag": "word_segmentation-gru_crf",
+                "task_priority_path": "lac",
+            },
+            "accurate": {
+                "task_class": SegWordTagTask,
+                "task_flag": "word_segmentation-wordtag",
+                "task_priority_path": "wordtag",
+                "linking": False,
             },
         },
         "default": {
-            "model": "plato-mini"
+            "mode": "base"
         }
     },
 }
@@ -247,6 +256,13 @@ def __init__(self, task, model=None, mode=None, device_id=0, **kwargs):
             )), "The {} name:{} is not in task:[{}]".format(tag, model, task)
         else:
             self.model = TASKS[task]['default'][ind_tag]
+
+        if "task_priority_path" in TASKS[self.task][tag][self.model]:
+            self.priority_path = TASKS[self.task][tag][self.model][
+                "task_priority_path"]
+        else:
+            self.priority_path = None
+
         # Set the device for the task
         device = get_env_device()
         if device == 'cpu' or device_id == -1:
@@ -261,7 +277,10 @@ def __init__(self, task, model=None, mode=None, device_id=0, **kwargs):
         self.kwargs = kwargs
         task_class = TASKS[self.task][tag][self.model]['task_class']
        self.task_instance = task_class(
-            model=self.model, task=self.task, **self.kwargs)
+            model=self.model,
+            task=self.task,
+            priority_path=self.priority_path,
+            **self.kwargs)
         task_list = TASKS.keys()
         Taskflow.task_list = task_list
 
@@ -297,7 +316,7 @@ def from_segments(self, *inputs):
         return results
 
     def interactive_mode(self, max_turn):
-        with self.task_instance.interactive_mode(max_turn=3):
+        with self.task_instance.interactive_mode(max_turn):
             while True:
                 human = input("[Human]:").strip()
                 if human.lower() == "exit":
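In practice (a hedged sketch, assuming the registry above): tasks that declare the same `task_priority_path` now resolve to one shared download directory, and `interactive_mode` now honours its `max_turn` argument instead of hard-coding 3:

```python
from paddlenlp import Taskflow

# These three tasks all carry task_priority_path="wordtag", so they share
# $HOME/.paddlenlp/taskflow/wordtag rather than keeping three separate copies.
ner = Taskflow("ner", mode="accurate")
seg = Taskflow("word_segmentation", mode="accurate")
wordtag = Taskflow("knowledge_mining", model="wordtag")

# The dialogue task's chat loop now forwards the caller's max_turn value.
dialogue = Taskflow("dialogue")
# dialogue.interactive_mode(max_turn=10)  # commented out: starts an input() loop
```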

paddlenlp/taskflow/text_correction.py

Lines changed: 4 additions & 4 deletions

@@ -58,7 +58,7 @@
 
 """
 
-TASK_MODEL_MAP = {"csc-ernie-1.0": "ernie-1.0"}
+TASK_MODEL_MAP = {"ernie-csc": "ernie-1.0"}
 
 
 class CSCTask(Task):
@@ -75,13 +75,13 @@ class CSCTask(Task):
         "pinyin_vocab": "pinyin_vocab.txt"
     }
     resource_files_urls = {
-        "csc-ernie-1.0": {
+        "ernie-csc": {
             "model_state": [
-                "https://bj.bcebos.com/paddlenlp/taskflow/text_correction/csc-ernie-1.0/model_state.pdparams",
+                "https://bj.bcebos.com/paddlenlp/taskflow/text_correction/ernie-csc/model_state.pdparams",
                 "cdc53e7e3985ffc78fedcdf8e6dca6d2"
             ],
             "pinyin_vocab": [
-                "https://bj.bcebos.com/paddlenlp/taskflow/text_correction/csc-ernie-1.0/pinyin_vocab.txt",
+                "https://bj.bcebos.com/paddlenlp/taskflow/text_correction/ernie-csc/pinyin_vocab.txt",
                 "5599a8116b6016af573d08f8e686b4b2"
             ],
         }
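For reference, a short sketch of calling the renamed model; the example sentence is taken from the docs diff above, and because `ernie-csc` is now the registered default, the explicit `model` argument could also be omitted:

```python
from paddlenlp import Taskflow

# "ernie-csc" replaces the old "csc-ernie-1.0" name and stays the default
# model for the text_correction task.
corrector = Taskflow("text_correction", model="ernie-csc")
print(corrector('遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇。'))
```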
