|
12 | 12 | # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. |
13 | 13 | # See the Mulan PSL v2 for more details. |
14 | 14 |
|
15 | | -import requests |
16 | 15 | import os |
| 16 | +import re |
17 | 17 | import zipfile |
18 | 18 | import tarfile |
| 19 | +import requests |
19 | 20 | from pygmtools.dataset_config import dataset_cfg |
20 | 21 | from pathlib import Path |
21 | 22 | from xml.etree.ElementTree import Element |
|
30 | 31 | import scipy.io as sio |
31 | 32 | import glob |
32 | 33 | import random |
| 34 | +from urllib.parse import parse_qs, urlencode, urlparse |
33 | 35 | from pygmtools.utils import download |
34 | 36 |
|
35 | 37 |
|
|
101 | 103 | } |
102 | 104 |
|
103 | 105 |
|
def _resolve_google_drive_url(url):
    r"""
    Resolve a Google Drive link to a direct download URL.

    For large files Google Drive answers with an HTML "can't scan for
    viruses" confirmation page instead of the file. This fetches the URL,
    and if it gets that page, extracts the confirmation form's action URL
    and hidden parameters to rebuild the real download URL. URLs that
    already respond with binary content are returned as-is (after any
    redirects).

    :param url: a Google Drive sharing or ``uc?export=download`` URL
    :return: a URL that should download the file directly
    :raises requests.HTTPError: if the initial request fails
    """
    parsed_url = urlparse(url)
    # Normalize "share"-style links carrying an ?id= query into the
    # canonical uc?export=download form.
    if parsed_url.netloc in ('drive.google.com', 'www.drive.google.com'):
        query = parse_qs(parsed_url.query)
        if 'id' in query and query['id']:
            url = 'https://drive.google.com/uc?export=download&id={}'.format(query['id'][0])

    session = requests.Session()
    # stream=True: we only need the headers (and, for the small HTML warning
    # page, the body) -- without it a direct-download URL would make this
    # helper download the entire archive just to discover its final URL.
    response = session.get(url, timeout=60, stream=True)
    try:
        response.raise_for_status()

        # Already a direct binary download: hand back the post-redirect URL.
        if response.headers.get('content-type', '').startswith(('application/', 'binary/')):
            return response.url

        # Otherwise this is the HTML virus-warning page; accessing .text
        # consumes the (small) streamed body.
        page = response.text
        action_match = re.search(r'<form[^>]*id="download-form"[^>]*action="([^"]+)"', page)
        if action_match is None:
            # Page layout changed or no form present -- best effort fallback.
            return response.url

        params = dict(re.findall(r'<input type="hidden" name="([^"]+)" value="([^"]*)"', page))
        if not params:
            return response.url

        return '{}?{}'.format(action_match.group(1), urlencode(params))
    finally:
        # Always release the connection; with stream=True an unread body
        # would otherwise keep the socket open.
        response.close()
| 132 | + |
| 133 | + |
def _resolve_download_urls(urls):
    r"""
    Return a copy of ``urls`` with every Google Drive link resolved to a
    direct download URL via :func:`_resolve_google_drive_url`; all other
    URLs pass through unchanged. Order is preserved.
    """
    gdrive_hosts = ('drive.google.com', 'drive.usercontent.google.com')
    return [
        _resolve_google_drive_url(candidate)
        if any(host in candidate for host in gdrive_hosts)
        else candidate
        for candidate in urls
    ]
| 142 | + |
| 143 | + |
104 | 144 | class PascalVOC: |
105 | 145 | r""" |
106 | 146 | Download and preprocess **PascalVOC Keypoint** dataset. |
@@ -1027,9 +1067,9 @@ def __init__(self, sets, obj_resize, **ds_dict): |
1027 | 1067 | CLASSES = dataset_cfg.IMC_PT_SparseGM.CLASSES |
1028 | 1068 | ROOT_DIR_NPZ = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ |
1029 | 1069 | ROOT_DIR_IMG = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_IMG |
1030 | | - URL = ['https://huggingface.co/heatingma/pygmtools/resolve/main/IMC-PT-SparseGM.tar.gz', |
1031 | | - 'https://drive.google.com/u/0/uc?id=1bisri2Ip1Of3RsUA8OBrdH5oa6HlH3k-&export=download'] |
1032 | | - |
| 1070 | + URL = ['https://huggingface.co/datasets/esflfei/IMC-PT-SparseGM/resolve/main/IMC-PT-SparseGM.tar.gz', |
| 1071 | + 'https://drive.google.com/uc?export=download&id=1C3xl_eWaCG3lL2C3vP8Fpsck88xZOHtg'] |
| 1072 | + |
1033 | 1073 | if len(ds_dict.keys()) > 0: |
1034 | 1074 | if 'MAX_KPT_NUM' in ds_dict.keys(): |
1035 | 1075 | MAX_KPT_NUM = ds_dict['MAX_KPT_NUM'] |
@@ -1079,7 +1119,7 @@ def download(self, url=None, retries=15): |
1079 | 1119 | os.makedirs(dirs) |
1080 | 1120 | print('Downloading dataset IMC-PT-SparseGM...') |
1081 | 1121 | filename = 'data/IMC-PT-SparseGM.tar.gz' |
1082 | | - download(filename=filename, url=url, to_cache=False) |
| 1122 | + download(filename=filename, url=_resolve_download_urls(url), to_cache=False) |
1083 | 1123 | try: |
1084 | 1124 | tar = tarfile.open(filename, "r") |
1085 | 1125 | except tarfile.ReadError as err: |
@@ -1202,7 +1242,7 @@ def __init__(self, sets, obj_resize, **ds_dict): |
1202 | 1242 | CLS_SPLIT = dataset_cfg.CUB2011.CLASS_SPLIT |
1203 | 1243 | ROOT_DIR = dataset_cfg.CUB2011.ROOT_DIR |
1204 | 1244 | URL = ['https://huggingface.co/heatingma/pygmtools/resolve/main/CUB_200_2011.tgz', |
1205 | | - 'https://drive.google.com/u/0/uc?export=download&confirm=B8eu&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'] |
| 1245 | + 'https://drive.google.com/uc?export=download&id=1jUTTlx8vv7doQUHSCdl39laV-y0FBC61'] |
1206 | 1246 | if len(ds_dict.keys()) > 0: |
1207 | 1247 | if 'ROOT_DIR' in ds_dict.keys(): |
1208 | 1248 | ROOT_DIR = ds_dict['ROOT_DIR'] |
@@ -1276,7 +1316,7 @@ def download(self, url=None, retries=50): |
1276 | 1316 | os.makedirs(dirs) |
1277 | 1317 | print('Downloading dataset CUB2011...') |
1278 | 1318 | filename = 'data/CUB_200_2011.tgz' |
1279 | | - download(filename=filename, url=url, to_cache=False) |
| 1319 | + download(filename=filename, url=_resolve_download_urls(url), to_cache=False) |
1280 | 1320 | try: |
1281 | 1321 | tar = tarfile.open(filename, "r") |
1282 | 1322 | except tarfile.ReadError as err: |
|
0 commit comments