
Commit 6cdd425

optimize taskflow download
1 parent 00af227 commit 6cdd425

File tree: 5 files changed (+118, -88 lines)


docs/model_zoo/taskflow.md

Lines changed: 10 additions & 10 deletions
@@ -438,10 +438,10 @@ from paddlenlp import Taskflow
 >>> from paddlenlp import Taskflow
 >>> nptag = Taskflow("knowledge_mining", model="nptag")
 >>> nptag("糖醋排骨")
->>> [{'text': '糖醋排骨', 'label': '菜品'}]
+[{'text': '糖醋排骨', 'label': '菜品'}]

-nptag(["糖醋排骨", "红曲霉菌"])
->>> [{'text': '糖醋排骨', 'label': '菜品'}, {'text': '红曲霉菌', 'label': '微生物'}]
+>>> nptag(["糖醋排骨", "红曲霉菌"])
+[{'text': '糖醋排骨', 'label': '菜品'}, {'text': '红曲霉菌', 'label': '微生物'}]

 # Use `linking` to output the coarse-grained category label `category`, i.e. the WordTag lexical tag.
 >>> nptag = Taskflow("knowledge_mining", model="nptag", linking=True)
@@ -471,8 +471,7 @@ nptag(["糖醋排骨", "红曲霉菌"])
 [{'source': '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇。', 'target': '遇到逆境时,我们必须勇于面对,而且要愈挫愈勇。', 'errors': [{'position': 3, 'correction': {'竟': '境'}}]}]

 # Batch prediction
->>> corrector(['遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇。',
-               '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。'])
+>>> corrector(['遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇。', '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。'])
 [{'source': '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇。', 'target': '遇到逆境时,我们必须勇于面对,而且要愈挫愈勇。', 'errors': [{'position': 3, 'correction': {'竟': '境'}}]}, {'source': '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。', 'target': '人生就是如此,经过磨练才能让自己更加茁壮,才能使自己更加乐观。', 'errors': [{'position': 18, 'correction': {'拙': '茁'}}]}]
 ```

@@ -628,16 +627,17 @@ nptag(["糖醋排骨", "红曲霉菌"])

 | Task name | Default path | |
 | :---: | :---: | :---: |
-| `Taskflow("word_segmentation", mode="base")` | `$HOME/.paddlenlp/taskflow/word_segmentation/lac` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/lexical_analysis) |
-| `Taskflow("word_segmentation", mode="accurate")` | `$HOME/.paddlenlp/taskflow/word_segmentation/wordtag` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
-| `Taskflow("ner", mode="fast")` | `$HOME/.paddlenlp/taskflow/ner/lac` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/lexical_analysis) |
-| `Taskflow("ner", mode="accurate")` | `$HOME/.paddlenlp/taskflow/ner/wordtag` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
+| `Taskflow("word_segmentation", mode="base")` | `$HOME/.paddlenlp/taskflow/lac` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/lexical_analysis) |
+| `Taskflow("word_segmentation", mode="accurate")` | `$HOME/.paddlenlp/taskflow/wordtag` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
+| `Taskflow("pos_tagging")` | `$HOME/.paddlenlp/taskflow/lac` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/lexical_analysis) |
+| `Taskflow("ner", mode="fast")` | `$HOME/.paddlenlp/taskflow/lac` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/lexical_analysis) |
+| `Taskflow("ner", mode="accurate")` | `$HOME/.paddlenlp/taskflow/wordtag` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
 | `Taskflow("text_correction", model="csc-ernie-1.0")` | `$HOME/.paddlenlp/taskflow/text_correction/csc-ernie-1.0` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_correction/ernie-csc) |
 | `Taskflow("dependency_parsing", model="ddparser")` | `$HOME/.paddlenlp/taskflow/dependency_parsing/ddparser` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/dependency_parsing/ddparser) |
 | `Taskflow("dependency_parsing", model="ddparser-ernie-1.0")` | `$HOME/.paddlenlp/taskflow/dependency_parsing/ddparser-ernie-1.0` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/dependency_parsing/ddparser) |
 | `Taskflow("dependency_parsing", model="ddparser-ernie-gram-zh")` | `$HOME/.paddlenlp/taskflow/dependency_parsing/ddparser-ernie-gram-zh` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/dependency_parsing/ddparser) |
 | `Taskflow("sentiment_analysis", model="skep_ernie_1.0_large_ch")` | `$HOME/.paddlenlp/taskflow/sentiment_analysis/skep_ernie_1.0_large_ch` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/sentiment_analysis/skep) |
-| `Taskflow("knowledge_mining", model="wordtag")` | `$HOME/.paddlenlp/taskflow/knowledge_mining/wordtag` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
+| `Taskflow("knowledge_mining", model="wordtag")` | `$HOME/.paddlenlp/taskflow/wordtag` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/ernie-ctm) |
 | `Taskflow("knowledge_mining", model="nptag")` | `$HOME/.paddlenlp/taskflow/knowledge_mining/nptag` | [Example](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_to_knowledge/nptag) |

 </div></details>
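The table above is the user-visible effect of this commit: tasks backed by the same underlying model (LAC or WordTag) now resolve to a single shared directory under `$HOME/.paddlenlp/taskflow/` instead of a per-task copy. A minimal usage sketch of the intended behaviour (the input sentence is illustrative, not taken from the docs):

```python
from paddlenlp import Taskflow

# Both tasks are backed by the LAC model; with the paths documented above they
# should now read from the same cache directory ($HOME/.paddlenlp/taskflow/lac),
# so constructing the second one does not trigger a second download.
seg = Taskflow("word_segmentation", mode="base")
ner = Taskflow("ner", mode="fast")

print(seg("PaddleNLP提供开箱即用的产业级NLP能力"))
print(ner("PaddleNLP提供开箱即用的产业级NLP能力"))
```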

paddlenlp/taskflow/dependency_parsing.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ class DDParserTask(Task):

     resource_files_names = {
         "model_state": "model_state.pdparams",
-        "word_vocab": "vocab.json",
+        "word_vocab": "word_vocab.json",
         "rel_vocab": "rel_vocab.json",
     }
     resource_files_urls = {

paddlenlp/taskflow/task.py

Lines changed: 5 additions & 1 deletion
@@ -33,9 +33,10 @@ class Task(metaclass=abc.ABCMeta):
         kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
     """

-    def __init__(self, model, task, **kwargs):
+    def __init__(self, model, task, priority_path=None, **kwargs):
         self.model = model
         self.task = task
+        self.priority_path = priority_path
         self.kwargs = kwargs
         self._usage = ""
         # The dygraph model instance
@@ -50,6 +51,9 @@ def __init__(self, model, task, **kwargs):
             'task_flag'] if 'task_flag' in self.kwargs else self.model
         if 'task_path' in self.kwargs:
             self._task_path = self.kwargs['task_path']
+        elif self.priority_path:
+            self._task_path = os.path.join(self._home_path, "taskflow",
+                                           self.priority_path)
         else:
             self._task_path = os.path.join(self._home_path, "taskflow",
                                            self.task, self.model)

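The new `priority_path` argument slots between the explicit `task_path` override and the default `<task>/<model>` layout when the task directory is resolved. A condensed sketch of the resulting precedence, paraphrased from the hunk above (not the actual `Task` API; `home` stands in for `self._home_path`):

```python
import os

def resolve_task_path(home, task, model, task_path=None, priority_path=None):
    """Paraphrase of the path resolution in Task.__init__ (illustrative only)."""
    if task_path is not None:
        # An explicit user-provided path always wins.
        return task_path
    if priority_path:
        # Shared directory such as "lac" or "wordtag", reused across tasks.
        return os.path.join(home, "taskflow", priority_path)
    # Fallback: per-task, per-model directory.
    return os.path.join(home, "taskflow", task, model)

# resolve_task_path("/home/u/.paddlenlp", "ner", "fast", priority_path="lac")
# -> "/home/u/.paddlenlp/taskflow/lac"
```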
paddlenlp/taskflow/taskflow.py

Lines changed: 90 additions & 71 deletions
@@ -39,33 +39,80 @@
 warnings.simplefilter(action='ignore', category=Warning, lineno=0, append=False)

 TASKS = {
+    'dependency_parsing': {
+        "models": {
+            "ddparser": {
+                "task_class": DDParserTask,
+                "task_flag": 'dependency_parsing-biaffine',
+            },
+            "ddparser-ernie-1.0": {
+                "task_class": DDParserTask,
+                "task_flag": 'dependency_parsing-ernie-1.0',
+            },
+            "ddparser-ernie-gram-zh": {
+                "task_class": DDParserTask,
+                "task_flag": 'dependency_parsing-ernie-gram-zh',
+            },
+        },
+        "default": {
+            "model": "ddparser",
+        }
+    },
+    'dialogue': {
+        "models": {
+            "plato-mini": {
+                "task_class": DialogueTask,
+                "task_flag": "dialogue-plato-mini"
+            },
+        },
+        "default": {
+            "model": "plato-mini",
+        }
+    },
     "knowledge_mining": {
         "models": {
             "wordtag": {
                 "task_class": WordTagTask,
                 "task_flag": 'knowledge_mining-wordtag',
+                "task_priority_path": "wordtag",
             },
             "nptag": {
                 "task_class": NPTagTask,
                 "task_flag": 'knowledge_mining-nptag',
             },
         },
         "default": {
-            "model": "wordtag"
+            "model": "wordtag",
+        }
+    },
+    "lexical_analysis": {
+        "models": {
+            "lac": {
+                "task_class": LacTask,
+                "hidden_size": 128,
+                "emb_dim": 128,
+                "task_flag": 'lexical_analysis-gru_crf',
+                "task_priority_path": "lac",
+            }
+        },
+        "default": {
+            "model": "lac"
         }
     },
     "ner": {
         "modes": {
             "accurate": {
                 "task_class": NERWordTagTask,
                 "task_flag": "ner-wordtag",
+                "task_priority_path": "wordtag",
                 "linking": False,
             },
             "fast": {
                 "task_class": NERLACTask,
                 "hidden_size": 128,
                 "emb_dim": 128,
                 "task_flag": "ner-lac",
+                "task_priority_path": "lac",
             }
         },
         "default": {
@@ -77,69 +124,37 @@
             "gpt-cpm-large-cn": {
                 "task_class": PoetryGenerationTask,
                 "task_flag": 'poetry_generation-gpt-cpm-large-cn',
+                "task_priority_path": "gpt-cpm-large-cn",
             },
         },
         "default": {
             "model": "gpt-cpm-large-cn",
         }
     },
-    "question_answering": {
-        "models": {
-            "gpt-cpm-large-cn": {
-                "task_class": QuestionAnsweringTask,
-                "task_flag": 'question_answering-gpt-cpm-large-cn',
-            },
-        },
-        "default": {
-            "model": "gpt-cpm-large-cn",
-        }
-    },
-    "lexical_analysis": {
+    "pos_tagging": {
         "models": {
             "lac": {
-                "task_class": LacTask,
+                "task_class": POSTaggingTask,
                 "hidden_size": 128,
                 "emb_dim": 128,
-                "task_flag": 'lexical_analysis-gru_crf',
+                "task_flag": 'pos_tagging-gru_crf',
+                "task_priority_path": "lac",
             }
         },
         "default": {
             "model": "lac"
         }
     },
-    "word_segmentation": {
-        "modes": {
-            "fast": {
-                "task_class": SegJiebaTask,
-                "task_flag": "word_segmentation-jieba",
-            },
-            "base": {
-                "task_class": SegLACTask,
-                "hidden_size": 128,
-                "emb_dim": 128,
-                "task_flag": "word_segmentation-gru_crf",
-            },
-            "accurate": {
-                "task_class": SegWordTagTask,
-                "task_flag": "word_segmentation-wordtag",
-                "linking": False,
-            },
-        },
-        "default": {
-            "mode": "base"
-        }
-    },
-    "pos_tagging": {
+    "question_answering": {
         "models": {
-            "lac": {
-                "task_class": POSTaggingTask,
-                "hidden_size": 128,
-                "emb_dim": 128,
-                "task_flag": 'pos_tagging-gru_crf',
-            }
+            "gpt-cpm-large-cn": {
+                "task_class": QuestionAnsweringTask,
+                "task_flag": 'question_answering-gpt-cpm-large-cn',
+                "task_priority_path": "gpt-cpm-large-cn",
+            },
         },
         "default": {
-            "model": "lac"
+            "model": "gpt-cpm-large-cn",
         }
     },
     'sentiment_analysis': {
@@ -157,25 +172,6 @@
             "model": "bilstm"
         }
     },
-    'dependency_parsing': {
-        "models": {
-            "ddparser": {
-                "task_class": DDParserTask,
-                "task_flag": 'dependency_parsing-biaffine',
-            },
-            "ddparser-ernie-1.0": {
-                "task_class": DDParserTask,
-                "task_flag": 'dependency_parsing-ernie-1.0',
-            },
-            "ddparser-ernie-gram-zh": {
-                "task_class": DDParserTask,
-                "task_flag": 'dependency_parsing-ernie-gram-zh',
-            },
-        },
-        "default": {
-            "model": "ddparser"
-        }
-    },
     'text_correction': {
         "models": {
             "csc-ernie-1.0": {
@@ -198,15 +194,28 @@
             "model": "simbert-base-chinese"
         }
     },
-    'dialogue': {
-        "models": {
-            "plato-mini": {
-                "task_class": DialogueTask,
-                "task_flag": "dialogue-plato-mini"
+    "word_segmentation": {
+        "modes": {
+            "fast": {
+                "task_class": SegJiebaTask,
+                "task_flag": "word_segmentation-jieba",
+            },
+            "base": {
+                "task_class": SegLACTask,
+                "hidden_size": 128,
+                "emb_dim": 128,
+                "task_flag": "word_segmentation-gru_crf",
+                "task_priority_path": "lac",
+            },
+            "accurate": {
+                "task_class": SegWordTagTask,
+                "task_flag": "word_segmentation-wordtag",
+                "task_priority_path": "wordtag",
+                "linking": False,
             },
         },
         "default": {
-            "model": "plato-mini"
+            "mode": "base"
         }
     },
 }
@@ -247,6 +256,13 @@ def __init__(self, task, model=None, mode=None, device_id=0, **kwargs):
             )), "The {} name:{} is not in task:[{}]".format(tag, model, task)
         else:
             self.model = TASKS[task]['default'][ind_tag]
+
+        if "task_priority_path" in TASKS[self.task][tag][self.model]:
+            self.priority_path = TASKS[self.task][tag][self.model][
+                "task_priority_path"]
+        else:
+            self.priority_path = None
+
         # Set the device for the task
         device = get_env_device()
         if device == 'cpu' or device_id == -1:
@@ -261,7 +277,10 @@ def __init__(self, task, model=None, mode=None, device_id=0, **kwargs):
         self.kwargs = kwargs
         task_class = TASKS[self.task][tag][self.model]['task_class']
         self.task_instance = task_class(
-            model=self.model, task=self.task, **self.kwargs)
+            model=self.model,
+            task=self.task,
+            priority_path=self.priority_path,
+            **self.kwargs)
         task_list = TASKS.keys()
         Taskflow.task_list = task_list

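In the registry above, `task_priority_path` is an optional per-model field; `Taskflow.__init__` now reads it when present and forwards it to the task class, which ties the shared directories in the docs table to the resolution logic in `task.py`. A stripped-down sketch of that lookup over a toy registry (not the real `TASKS` dict):

```python
# Toy registry mirroring the shape of TASKS; values are stand-ins.
TOY_TASKS = {
    "ner": {
        "modes": {
            "fast": {"task_class": "NERLACTask", "task_priority_path": "lac"},
            "accurate": {"task_class": "NERWordTagTask", "task_priority_path": "wordtag"},
        },
        "default": {"mode": "accurate"},
    },
}

def lookup_priority_path(task, tag, name):
    # Mirrors the new branch in Taskflow.__init__: fall back to None when the
    # selected model entry declares no task_priority_path.
    return TOY_TASKS[task][tag][name].get("task_priority_path", None)

print(lookup_priority_path("ner", "modes", "fast"))      # -> lac
print(lookup_priority_path("ner", "modes", "accurate"))  # -> wordtag
```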
paddlenlp/utils/downloader.py

Lines changed: 12 additions & 5 deletions
@@ -362,8 +362,7 @@ def __init__(self, task, command="taskflow", addition=None):
         self.command = command
         self.task = task
         self.addition = addition
-        self.hash_flag = _md5(str(uuid.uuid1())[9:18]) + "-" + str(
-            int(time.time()))
+        self._initialize()

     def uri_path(self, server_url, api):
         srv = server_url
@@ -376,30 +375,38 @@ def uri_path(self, server_url, api):
         srv += api
         return srv

+    def _initialize(self):
+        etime = str(int(time.time()))
+        self.cache_info = _md5(str(uuid.uuid1())[-12:])
+        self.hash_flag = _md5(str(uuid.uuid1())[9:18]) + "-" + etime
+
     def request_check(self, task, command, addition):
         if task is None:
             return SUCCESS_STATUS
         payload = {'word': self.task}
-        api_url = self.uri_path(DOWNLOAD_SERVER, 'search')
+        api_url = self.uri_path(DOWNLOAD_SERVER, 'stat')
         cache_path = os.path.join("~")
         if os.path.exists(cache_path):
             extra = {
                 "command": self.command,
                 "mtime": os.stat(cache_path).st_mtime,
-                "hub_name": self.hash_flag
+                "hub_name": self.hash_flag,
+                "cache_info": self.cache_info
             }
         else:
             extra = {
                 "command": self.command,
                 "mtime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
-                "hub_name": self.hash_flag
+                "hub_name": self.hash_flag,
+                "cache_info": self.cache_info
             }
         if addition is not None:
             extra.update({"addition": addition})
         try:
             import paddle
             payload['hub_version'] = " "
             payload['paddle_version'] = paddle.__version__.split('-')[0]
+            payload['from'] = 'ppnlp'
             payload['extra'] = json.dumps(extra)
             r = requests.get(api_url, payload, timeout=1).json()
             if r.get("update_cache", 0) == 1:

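On the download-stats side, the commit adds a `cache_info` hash alongside the existing `hash_flag`, points the check at the `stat` endpoint, and tags the payload with `from: ppnlp`. A rough sketch of the payload `request_check` now assembles (field values here are illustrative; the real ones come from `uuid`, `time`, and the local Paddle install):

```python
import json

# Illustrative payload only; see request_check above for the real construction.
extra = {
    "command": "taskflow",
    "mtime": "2022-01-01 00:00:00",
    "hub_name": "<md5(uuid fragment)>-<epoch seconds>",  # self.hash_flag
    "cache_info": "<md5(uuid fragment)>",                # new in this commit
}
payload = {
    "word": "ner",                        # the task being checked
    "hub_version": " ",
    "paddle_version": "2.2.1",            # paddle.__version__.split('-')[0]
    "from": "ppnlp",                      # new in this commit
    "extra": json.dumps(extra),
}
# Sent via requests.get(api_url, payload, timeout=1), where api_url now ends in '/stat'.
```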