|
20 | 20 | import sys
|
21 | 21 | import os.path as osp
|
22 | 22 | import shutil
|
| 23 | +import json |
23 | 24 | import requests
|
24 | 25 | import hashlib
|
25 | 26 | import tarfile
|
26 | 27 | import zipfile
|
27 | 28 | import time
|
| 29 | +import uuid |
| 30 | +import threading |
28 | 31 | from collections import OrderedDict
|
| 32 | +from .env import DOWNLOAD_SERVER, SUCCESS_STATUS, FAILED_STATUS |
29 | 33 |
|
30 | 34 | try:
|
31 | 35 | from tqdm import tqdm
|
@@ -236,6 +240,15 @@ def _md5check(fullname, md5sum=None):
|
236 | 240 | return True
|
237 | 241 |
|
238 | 242 |
|
| 243 | +def _md5(text): |
| 244 | + """ |
| 245 | + Calculate the md5 value of the input text. |
| 246 | + """ |
| 247 | + |
| 248 | + md5code = hashlib.md5(text.encode()) |
| 249 | + return md5code.hexdigest() |
| 250 | + |
| 251 | + |
239 | 252 | def _decompress(fname):
|
240 | 253 | """
|
241 | 254 | Decompress for zip and tar file
|
@@ -336,3 +349,64 @@ def _is_a_single_dir(file_list):
|
336 | 349 | if file_name != new_file_list[i].split(os.sep)[0]:
|
337 | 350 | return False
|
338 | 351 | return True
|
| 352 | + |
| 353 | + |
| 354 | +class DownloaderCheck(threading.Thread): |
| 355 | + """ |
| 356 | + Check the resource applicability when downloading the models. |
| 357 | + """ |
| 358 | + |
| 359 | + def __init__(self, task, command="taskflow", addition=None): |
| 360 | + threading.Thread.__init__(self) |
| 361 | + self.command = command |
| 362 | + self.task = task |
| 363 | + self.addition = addition |
| 364 | + self.hash_flag = _md5(str(uuid.uuid1())[9:18]) + "-" + str( |
| 365 | + int(time.time())) |
| 366 | + |
| 367 | + def uri_path(self, server_url, api): |
| 368 | + srv = server_url |
| 369 | + if server_url.endswith('/'): |
| 370 | + srv = server_url[:-1] |
| 371 | + if api.startswith('/'): |
| 372 | + srv += api |
| 373 | + else: |
| 374 | + api = '/' + api |
| 375 | + srv += api |
| 376 | + return srv |
| 377 | + |
| 378 | + def request_check(self, task, command, addition): |
| 379 | + if task is None: |
| 380 | + return SUCCESS_STATUS |
| 381 | + payload = {'word': self.task} |
| 382 | + api_url = self.uri_path(DOWNLOAD_SERVER, 'search') |
| 383 | + cache_path = os.path.join("~") |
| 384 | + if os.path.exists(cache_path): |
| 385 | + extra = { |
| 386 | + "command": self.command, |
| 387 | + "mtime": os.stat(cache_path).st_mtime, |
| 388 | + "hub_name": self.hash_flag |
| 389 | + } |
| 390 | + else: |
| 391 | + extra = { |
| 392 | + "command": self.command, |
| 393 | + "mtime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), |
| 394 | + "hub_name": self.hash_flag |
| 395 | + } |
| 396 | + if addition is not None: |
| 397 | + extra.update({"addition": addition}) |
| 398 | + try: |
| 399 | + import paddle |
| 400 | + payload['hub_version'] = " " |
| 401 | + payload['paddle_version'] = paddle.__version__.split('-')[0] |
| 402 | + payload['extra'] = json.dumps(extra) |
| 403 | + r = requests.get(api_url, payload, timeout=1).json() |
| 404 | + if r.get("update_cache", 0) == 1: |
| 405 | + return SUCCESS_STATUS |
| 406 | + else: |
| 407 | + return FAILED_STATUS |
| 408 | + except Exception as err: |
| 409 | + return FAILED_STATUS |
| 410 | + |
| 411 | + def run(self): |
| 412 | + self.request_check(self.task, self.command, self.addition) |
0 commit comments