
Commit e51e699

Merge pull request #1166 from wawltor/taskflow_download_check
update the download check for the Taskflow
2 parents 1b4821b + 6feb45a commit e51e699

12 files changed: +59 -100 lines
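Taken together, the hunks below follow one pattern: the trailing log-name string is dropped from every download_file call, and the reporting hook becomes a single download_check(task_flag) call in the Task base class, with task_flag supplied through the Taskflow config. A minimal before/after sketch of a call site, with the helper signatures inferred from the call sites in this diff (both helpers live in paddlenlp/taskflow/utils.py):

# Before: each call site passed an extra log-name string as the last argument.
word_dict_path = download_file(
    self._task_path, "lac_params" + os.path.sep + "word.dic",
    URLS['lac_params'][0], URLS['lac_params'][1], 'lexical_analysis')

# After: download_file takes only the save path, filename, URL and md5
# (argument names assumed); the check moves into Task.__init__ as
# download_check(self._task_flag), where self._task_flag falls back to the
# model name when the config sets no 'task_flag'.
word_dict_path = download_file(
    self._task_path, "lac_params" + os.path.sep + "word.dic",
    URLS['lac_params'][0], URLS['lac_params'][1])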

paddlenlp/taskflow/dependency_parsing.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def __init__(self,
                 ddparser, ddparser-ernie-1.0 and ddoarser-ernie-gram-zh")
         word_vocab_path = download_file(
             self._task_path, self.model + os.path.sep + "word_vocab.json",
-            URLS[self.model][0], URLS[self.model][1], self.model)
+            URLS[self.model][0], URLS[self.model][1])
         rel_vocab_path = download_file(
             self._task_path, self.model + os.path.sep + "rel_vocab.json",
             URLS[self.model][0], URLS[self.model][1])

paddlenlp/taskflow/knowledge_mining.py

Lines changed: 3 additions & 5 deletions
@@ -143,13 +143,11 @@ class WordTagTask(Task):
     def __init__(self, model, task, **kwargs):
         super().__init__(model=model, task=task, **kwargs)
         self._static_mode = False
-        self._log_name = self.kwargs[
-            'log_name'] if 'log_name' in self.kwargs else 'wordtag'
         self._linking = self.kwargs[
             'linking'] if 'linking' in self.kwargs else False
-        term_schema_path = download_file(
-            self._task_path, "termtree_type.csv", URLS['termtree_type'][0],
-            URLS['termtree_type'][1], self._log_name)
+        term_schema_path = download_file(self._task_path, "termtree_type.csv",
+                                         URLS['termtree_type'][0],
+                                         URLS['termtree_type'][1])
         term_data_path = download_file(self._task_path, "TermTree.V1.0",
                                        URLS['TermTree.V1.0'][0],
                                        URLS['TermTree.V1.0'][1])

paddlenlp/taskflow/lexical_analysis.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def __init__(self, task, model, **kwargs):
         self._usage = usage
         word_dict_path = download_file(
             self._task_path, "lac_params" + os.path.sep + "word.dic",
-            URLS['lac_params'][0], URLS['lac_params'][1], 'lexical_analysis')
+            URLS['lac_params'][0], URLS['lac_params'][1])
         tag_dict_path = download_file(
             self._task_path, "lac_params" + os.path.sep + "tag.dic",
             URLS['lac_params'][0], URLS['lac_params'][1])

paddlenlp/taskflow/poetry_generation.py

Lines changed: 0 additions & 9 deletions
@@ -59,12 +59,3 @@ class PoetryGenerationTask(TextGenerationTask):

    def __init__(self, task, model, **kwargs):
        super().__init__(task=task, model=model, **kwargs)
-        if self._static_mode:
-            download_file(
-                self._task_path, "static" + os.path.sep + "inference.pdiparams",
-                URLS[self.model][0], URLS[self.model][1], "poetry_generation")
-            self._get_inference_model()
-        else:
-            self._construct_model(model)
-            self._construct_tokenizer(model)
-        self.kwargs['generation_task'] = 'poetry_generation'

paddlenlp/taskflow/pos_tagging.py

Lines changed: 0 additions & 30 deletions
@@ -24,13 +24,6 @@
from .utils import download_file
from .lexical_analysis import load_vocab, LacTask

-URLS = {
-    "pos_tagging_params": [
-        "https://paddlenlp.bj.bcebos.com/taskflow/lexical_analysis/lac/lac_params.tar.gz",
-        'ee9a3eaba5f74105410410e3c5b28fbc'
-    ],
-}
-
usage = r"""
from paddlenlp import Taskflow

@@ -58,29 +51,6 @@ class POSTaggingTask(LacTask):

    def __init__(self, task, model, **kwargs):
        super().__init__(task=task, model=model, **kwargs)
-        self._static_mode = False
-        self._usage = usage
-        word_dict_path = download_file(
-            self._task_path, "lac_params" + os.path.sep + "word.dic",
-            URLS['pos_tagging_params'][0], URLS['pos_tagging_params'][1],
-            'pos_tagging')
-        tag_dict_path = download_file(
-            self._task_path, "lac_params" + os.path.sep + "tag.dic",
-            URLS['pos_tagging_params'][0], URLS['pos_tagging_params'][1])
-        q2b_dict_path = download_file(
-            self._task_path, "lac_params" + os.path.sep + "q2b.dic",
-            URLS['pos_tagging_params'][0], URLS['pos_tagging_params'][1])
-        self._word_vocab = load_vocab(word_dict_path)
-        self._tag_vocab = load_vocab(tag_dict_path)
-        self._q2b_vocab = load_vocab(q2b_dict_path)
-        self._id2word_dict = dict(
-            zip(self._word_vocab.values(), self._word_vocab.keys()))
-        self._id2tag_dict = dict(
-            zip(self._tag_vocab.values(), self._tag_vocab.keys()))
-        if self._static_mode:
-            self._get_inference_model()
-        else:
-            self._construct_model(model)

    def _postprocess(self, inputs):
        """

paddlenlp/taskflow/question_answering.py

Lines changed: 0 additions & 9 deletions
@@ -59,12 +59,3 @@ class QuestionAnsweringTask(TextGenerationTask):

    def __init__(self, task, model, **kwargs):
        super().__init__(task=task, model=model, **kwargs)
-        if self._static_mode:
-            download_file(
-                self._task_path, "static" + os.path.sep + "inference.pdiparams",
-                URLS[self.model][0], URLS[self.model][1], "question_answering")
-            self._get_inference_model()
-        else:
-            self._construct_model(model)
-            self._construct_tokenizer(model)
-        self.kwargs['generation_task'] = 'question_answering'

paddlenlp/taskflow/sentiment_analysis.py

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ def _construct_model(self, model):
        model_instance = SkepSequenceModel.from_pretrained(
            model, num_classes=len(self._label_map))
        model_path = download_file(self._task_path, model + ".pdparams",
-                                   URLS[model][0], URLS[model][1], model)
+                                   URLS[model][0], URLS[model][1])
        state_dict = paddle.load(model_path)
        model_instance.set_state_dict(state_dict)
        self._model = model_instance

paddlenlp/taskflow/task.py

Lines changed: 4 additions & 1 deletion
@@ -19,7 +19,7 @@
import paddle
from ..utils.env import PPNLP_HOME
from ..utils.log import logger
-from .utils import static_mode_guard, dygraph_mode_guard
+from .utils import download_check, static_mode_guard, dygraph_mode_guard


class Task(metaclass=abc.ABCMeta):
@@ -44,6 +44,9 @@ def __init__(self, model, task, **kwargs):
        self._config = None
        self._task_path = os.path.join(PPNLP_HOME, "taskflow", self.task,
                                       self.model)
+        self._task_flag = self.kwargs[
+            'task_flag'] if 'task_flag' in self.kwargs else self.model
+        download_check(self._task_flag)

    @abstractmethod
    def _construct_model(self, model):
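With this, the base class owns the download telemetry: Task.__init__ reads 'task_flag' from its kwargs (falling back to the model name) and calls download_check once, so subclasses no longer thread a log name through each download_file call. A hypothetical subclass illustrating the division of labour (MyTask, its URLS table and the vocab file are made up for illustration; only _construct_model is visible as abstract in this hunk, other abstract methods are elided):

class MyTask(Task):
    def __init__(self, task, model, **kwargs):
        # The base __init__ sets self._task_path and self._task_flag and
        # already calls download_check(self._task_flag).
        super().__init__(task=task, model=model, **kwargs)
        # The subclass only fetches its own resources, with no log-name argument.
        vocab_path = download_file(self._task_path, "vocab.json",
                                   URLS[model][0], URLS[model][1])

    def _construct_model(self, model):
        ...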

paddlenlp/taskflow/taskflow.py

Lines changed: 22 additions & 11 deletions
@@ -36,7 +36,7 @@
        "models": {
            "wordtag": {
                "task_class": WordTagTask,
-                "log_name": 'knowledge_mining_wordtag',
+                "task_flag": 'knowledge_mining-wordtag',
                "linking": True,
            }
        },
@@ -48,7 +48,7 @@
        "models": {
            "wordtag": {
                "task_class": NERTask,
-                "log_name": 'ner_wordtag',
+                "task_flag": 'ner-wordtag',
                "linking": False,
            }
        },
@@ -60,6 +60,7 @@
        "models": {
            "gpt-cpm-large-cn": {
                "task_class": PoetryGenerationTask,
+                "task_flag": 'poetry_generation-gpt-cpm-large-cn',
            },
        },
        "default": {
@@ -70,6 +71,7 @@
        "models": {
            "gpt-cpm-large-cn": {
                "task_class": QuestionAnsweringTask,
+                "task_flag": 'question_answering-gpt-cpm-large-cn',
            },
        },
        "default": {
@@ -81,7 +83,8 @@
            "lac": {
                "task_class": LacTask,
                "hidden_size": 128,
-                "emb_dim": 128
+                "emb_dim": 128,
+                "task_flag": 'lexical_analysis-gru_crf',
            }
        },
        "default": {
@@ -93,7 +96,8 @@
            "lac": {
                "task_class": WordSegmentationTask,
                "hidden_size": 128,
-                "emb_dim": 128
+                "emb_dim": 128,
+                "task_flag": 'word_segmentation-gru_crf',
            }
        },
        "default": {
@@ -105,7 +109,8 @@
            "lac": {
                "task_class": POSTaggingTask,
                "hidden_size": 128,
-                "emb_dim": 128
+                "emb_dim": 128,
+                "task_flag": 'pos_tagging-gru_crf',
            }
        },
        "default": {
@@ -115,10 +120,12 @@
    'sentiment_analysis': {
        "models": {
            "bilstm": {
-                "task_class": SentaTask
+                "task_class": SentaTask,
+                "task_flag": 'sentiment_analysis-bilstm',
            },
            "skep_ernie_1.0_large_ch": {
-                "task_class": SkepTask
+                "task_class": SkepTask,
+                "task_flag": 'sentiment_analysis-skep_ernie_1.0_large_ch',
            }
        },
        "default": {
@@ -128,13 +135,16 @@
    'dependency_parsing': {
        "models": {
            "ddparser": {
-                "task_class": DDParserTask
+                "task_class": DDParserTask,
+                "task_flag": 'dependency_parsing-biaffine',
            },
            "ddparser-ernie-1.0": {
-                "task_class": DDParserTask
+                "task_class": DDParserTask,
+                "task_flag": 'dependency_parsing-ernie-1.0',
            },
            "ddparser-ernie-gram-zh": {
-                "task_class": DDParserTask
+                "task_class": DDParserTask,
+                "task_flag": 'dependency_parsing-ernie-gram-zh',
            },
        },
        "default": {
@@ -144,7 +154,8 @@
    'text_correction': {
        "models": {
            "csc-ernie-1.0": {
-                "task_class": CSCTask
+                "task_class": CSCTask,
+                "task_flag": "text_correction-csc-ernie-1.0"
            },
        },
        "default": {

paddlenlp/taskflow/text_generation.py

Lines changed: 9 additions & 0 deletions
@@ -55,6 +55,15 @@ def __init__(self, task, model, **kwargs):
        super().__init__(task=task, model=model, **kwargs)
        self._static_mode = True
        self._usage = usage
+        if self._static_mode:
+            download_file(self._task_path,
+                          "static" + os.path.sep + "inference.pdiparams",
+                          URLS[self.model][0], URLS[self.model][1])
+            self._get_inference_model()
+        else:
+            self._construct_model(model)
+            self._construct_tokenizer(model)
+        self.kwargs['generation_task'] = task

    def _construct_input_spec(self):
        """
