Skip to content

Commit 147a943

Browse files
authored
add the download check for the taskflow (#758)
* add the download check for tht taskflow * update the taskflow download message * update the user generate for the paddlenlp
1 parent edde6aa commit 147a943

File tree

5 files changed

+89
-11
lines changed

5 files changed

+89
-11
lines changed

paddlenlp/taskflow/sentiment_analysis.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,19 @@ def _construct_model(self, model):
104104
network = self.kwargs['network']
105105
if network == "bow":
106106
model = BoWModel(vocab_size, num_classes, padding_idx=pad_token_id)
107-
model_full_name = download_file(self.model, "senta_bow.pdparams",
108-
URLS['senta_bow'][0],
109-
URLS['senta_bow'][1])
107+
model_full_name = download_file(
108+
self.model, "senta_bow.pdparams", URLS['senta_bow'][0],
109+
URLS['senta_bow'][1], "sentiment_analysis")
110110
elif network == "lstm":
111111
model = LSTMModel(
112112
vocab_size,
113113
num_classes,
114114
direction='forward',
115115
padding_idx=pad_token_id,
116116
pooling_type='max')
117-
model_full_name = download_file(self.model, "senta_lstm.pdparams",
118-
URLS['senta_lstm'][0],
119-
URLS['senta_lstm'][1])
117+
model_full_name = download_file(
118+
self.model, "senta_lstm.pdparams", URLS['senta_lstm'][0],
119+
URLS['senta_lstm'][1], "sentiment_analysis")
120120
else:
121121
raise ValueError(
122122
"Unknown network: {}, it must be one of bow, lstm.".format(

paddlenlp/taskflow/text2knowledge.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ class WordTagTask(Task):
147147

148148
def __init__(self, model, task, **kwargs):
149149
super().__init__(model=model, task=task, **kwargs)
150-
term_schema_path = download_file(self.model, "termtree_type.csv",
151-
URLS['termtree_type'][0],
152-
URLS['termtree_type'][1])
150+
term_schema_path = download_file(
151+
self.model, "termtree_type.csv", URLS['termtree_type'][0],
152+
URLS['termtree_type'][1], "text2knowledge")
153153
term_data_path = download_file(self.model, "TermTree.V1.0",
154154
URLS['TermTree.V1.0'][0],
155155
URLS['TermTree.V1.0'][1])

paddlenlp/taskflow/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import os
1717
from paddle.dataset.common import md5file
18-
from ..utils.downloader import get_path_from_url
18+
from ..utils.downloader import get_path_from_url, DownloaderCheck
1919
from ..utils.env import MODEL_HOME
2020

2121
DOC_FORMAT = r"""
@@ -24,7 +24,7 @@
2424
"""
2525

2626

27-
def download_file(save_dir, filename, url, md5=None):
27+
def download_file(save_dir, filename, url, md5=None, task=None):
2828
"""
2929
Download the file from the url to specified directory.
3030
Check md5 value when the file is exists, if the md5 value is the same as the existed file, just use
@@ -36,6 +36,7 @@ def download_file(save_dir, filename, url, md5=None):
3636
url(string): The url downling the file.
3737
md5(string, optional): The md5 value that checking the version downloaded.
3838
"""
39+
DownloaderCheck(task).start()
3940
default_root = os.path.join(MODEL_HOME, save_dir)
4041
fullname = os.path.join(default_root, filename)
4142
if os.path.exists(fullname):

paddlenlp/utils/downloader.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,16 @@
2020
import sys
2121
import os.path as osp
2222
import shutil
23+
import json
2324
import requests
2425
import hashlib
2526
import tarfile
2627
import zipfile
2728
import time
29+
import uuid
30+
import threading
2831
from collections import OrderedDict
32+
from .env import DOWNLOAD_SERVER, SUCCESS_STATUS, FAILED_STATUS
2933

3034
try:
3135
from tqdm import tqdm
@@ -236,6 +240,15 @@ def _md5check(fullname, md5sum=None):
236240
return True
237241

238242

243+
def _md5(text):
244+
"""
245+
Calculate the md5 value of the input text.
246+
"""
247+
248+
md5code = hashlib.md5(text.encode())
249+
return md5code.hexdigest()
250+
251+
239252
def _decompress(fname):
240253
"""
241254
Decompress for zip and tar file
@@ -336,3 +349,64 @@ def _is_a_single_dir(file_list):
336349
if file_name != new_file_list[i].split(os.sep)[0]:
337350
return False
338351
return True
352+
353+
354+
class DownloaderCheck(threading.Thread):
355+
"""
356+
Check the resource applicability when downloading the models.
357+
"""
358+
359+
def __init__(self, task, command="taskflow", addition=None):
360+
threading.Thread.__init__(self)
361+
self.command = command
362+
self.task = task
363+
self.addition = addition
364+
self.hash_flag = _md5(str(uuid.uuid1())[9:18]) + "-" + str(
365+
int(time.time()))
366+
367+
def uri_path(self, server_url, api):
368+
srv = server_url
369+
if server_url.endswith('/'):
370+
srv = server_url[:-1]
371+
if api.startswith('/'):
372+
srv += api
373+
else:
374+
api = '/' + api
375+
srv += api
376+
return srv
377+
378+
def request_check(self, task, command, addition):
379+
if task is None:
380+
return SUCCESS_STATUS
381+
payload = {'word': self.task}
382+
api_url = self.uri_path(DOWNLOAD_SERVER, 'search')
383+
cache_path = os.path.join("~")
384+
if os.path.exists(cache_path):
385+
extra = {
386+
"command": self.command,
387+
"mtime": os.stat(cache_path).st_mtime,
388+
"hub_name": self.hash_flag
389+
}
390+
else:
391+
extra = {
392+
"command": self.command,
393+
"mtime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
394+
"hub_name": self.hash_flag
395+
}
396+
if addition is not None:
397+
extra.update({"addition": addition})
398+
try:
399+
import paddle
400+
payload['hub_version'] = " "
401+
payload['paddle_version'] = paddle.__version__.split('-')[0]
402+
payload['extra'] = json.dumps(extra)
403+
r = requests.get(api_url, payload, timeout=1).json()
404+
if r.get("update_cache", 0) == 1:
405+
return SUCCESS_STATUS
406+
else:
407+
return FAILED_STATUS
408+
except Exception as err:
409+
return FAILED_STATUS
410+
411+
def run(self):
412+
self.request_check(self.task, self.command, self.addition)

paddlenlp/utils/env.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,6 @@ def _get_sub_home(directory, parent_home=_get_ppnlp_home()):
5151
PPNLP_HOME = _get_ppnlp_home()
5252
MODEL_HOME = _get_sub_home('models')
5353
DATA_HOME = _get_sub_home('datasets')
54+
DOWNLOAD_SERVER = "http://paddlepaddle.org.cn/paddlehub"
55+
FAILED_STATUS = -1
56+
SUCCESS_STATUS = 0

0 commit comments

Comments
 (0)