Skip to content

Commit f718a89

Browse files
committed
fix & update dataset links
1 parent 569bc35 commit f718a89

File tree

1 file changed

+47
-7
lines changed

1 file changed

+47
-7
lines changed

pygmtools/dataset.py

Lines changed: 47 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,11 @@
1212
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
1313
# See the Mulan PSL v2 for more details.
1414

15-
import requests
1615
import os
16+
import re
1717
import zipfile
1818
import tarfile
19+
import requests
1920
from pygmtools.dataset_config import dataset_cfg
2021
from pathlib import Path
2122
from xml.etree.ElementTree import Element
@@ -30,6 +31,7 @@
3031
import scipy.io as sio
3132
import glob
3233
import random
34+
from urllib.parse import parse_qs, urlencode, urlparse
3335
from pygmtools.utils import download
3436

3537

@@ -101,6 +103,44 @@
101103
}
102104

103105

106+
def _resolve_google_drive_url(url):
    r"""
    Resolve the Google Drive virus warning page to the actual download URL.

    A ``drive.google.com`` link carrying an ``id`` query parameter is first
    rewritten to the canonical ``uc?export=download&id=...`` form. The URL is
    then requested once:

    * if the response is already binary content (``application/*`` or
      ``binary/*``), the final (post-redirect) URL is returned;
    * otherwise the HTML is searched for the ``download-form`` form and its
      hidden inputs, and the confirmed download URL is rebuilt from them;
    * if the form or its hidden inputs cannot be found, the final URL of the
      request is returned unchanged.

    :param url: str, Google Drive URL to resolve
    :return: str, URL that should point at the file itself
    :raises requests.RequestException: if the request fails, times out or
        returns an HTTP error status
    """
    # Function-scope import: ``html`` is only needed by this helper.
    import html

    parsed_url = urlparse(url)
    if parsed_url.netloc in ('drive.google.com', 'www.drive.google.com'):
        file_ids = parse_qs(parsed_url.query).get('id')
        if file_ids:
            url = 'https://drive.google.com/uc?export=download&id={}'.format(file_ids[0])

    # ``stream=True`` defers fetching the body: when the server answers with
    # the file itself only the headers are read, instead of pulling the whole
    # archive into memory just to learn the final URL. The context managers
    # release the connection and the session on every exit path.
    with requests.Session() as session:
        with session.get(url, timeout=60, stream=True) as response:
            response.raise_for_status()
            final_url = response.url

            content_type = response.headers.get('content-type', '')
            if content_type.startswith(('application/', 'binary/')):
                return final_url

            # Not a binary payload: read the (small) HTML page for parsing.
            page = response.text

    action_match = re.search(r'<form[^>]*id="download-form"[^>]*action="([^"]+)"', page)
    if action_match is None:
        return final_url

    hidden_inputs = re.findall(r'<input type="hidden" name="([^"]+)" value="([^"]*)"', page)
    if not hidden_inputs:
        return final_url

    # Attribute values scraped from HTML are entity-encoded (e.g. ``&amp;``);
    # decode them before they are re-encoded into the query string.
    action = html.unescape(action_match.group(1))
    params = {html.unescape(name): html.unescape(value) for name, value in hidden_inputs}

    # Do not emit a second '?' if the form action already has a query string.
    separator = '&' if '?' in action else '?'
    return '{}{}{}'.format(action, separator, urlencode(params))
132+
133+
134+
def _resolve_download_urls(urls):
135+
resolved_urls = []
136+
for candidate in urls:
137+
if 'drive.google.com' in candidate or 'drive.usercontent.google.com' in candidate:
138+
resolved_urls.append(_resolve_google_drive_url(candidate))
139+
else:
140+
resolved_urls.append(candidate)
141+
return resolved_urls
142+
143+
104144
class PascalVOC:
105145
r"""
106146
Download and preprocess **PascalVOC Keypoint** dataset.
@@ -1027,9 +1067,9 @@ def __init__(self, sets, obj_resize, **ds_dict):
10271067
CLASSES = dataset_cfg.IMC_PT_SparseGM.CLASSES
10281068
ROOT_DIR_NPZ = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ
10291069
ROOT_DIR_IMG = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_IMG
1030-
URL = ['https://huggingface.co/heatingma/pygmtools/resolve/main/IMC-PT-SparseGM.tar.gz',
1031-
'https://drive.google.com/u/0/uc?id=1bisri2Ip1Of3RsUA8OBrdH5oa6HlH3k-&export=download']
1032-
1070+
URL = ['https://huggingface.co/datasets/esflfei/IMC-PT-SparseGM/resolve/main/IMC-PT-SparseGM.tar.gz',
1071+
'https://drive.google.com/uc?export=download&id=1C3xl_eWaCG3lL2C3vP8Fpsck88xZOHtg']
1072+
10331073
if len(ds_dict.keys()) > 0:
10341074
if 'MAX_KPT_NUM' in ds_dict.keys():
10351075
MAX_KPT_NUM = ds_dict['MAX_KPT_NUM']
@@ -1079,7 +1119,7 @@ def download(self, url=None, retries=15):
10791119
os.makedirs(dirs)
10801120
print('Downloading dataset IMC-PT-SparseGM...')
10811121
filename = 'data/IMC-PT-SparseGM.tar.gz'
1082-
download(filename=filename, url=url, to_cache=False)
1122+
download(filename=filename, url=_resolve_download_urls(url), to_cache=False)
10831123
try:
10841124
tar = tarfile.open(filename, "r")
10851125
except tarfile.ReadError as err:
@@ -1202,7 +1242,7 @@ def __init__(self, sets, obj_resize, **ds_dict):
12021242
CLS_SPLIT = dataset_cfg.CUB2011.CLASS_SPLIT
12031243
ROOT_DIR = dataset_cfg.CUB2011.ROOT_DIR
12041244
URL = ['https://huggingface.co/heatingma/pygmtools/resolve/main/CUB_200_2011.tgz',
1205-
'https://drive.google.com/u/0/uc?export=download&confirm=B8eu&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45']
1245+
'https://drive.google.com/uc?export=download&id=1jUTTlx8vv7doQUHSCdl39laV-y0FBC61']
12061246
if len(ds_dict.keys()) > 0:
12071247
if 'ROOT_DIR' in ds_dict.keys():
12081248
ROOT_DIR = ds_dict['ROOT_DIR']
@@ -1276,7 +1316,7 @@ def download(self, url=None, retries=50):
12761316
os.makedirs(dirs)
12771317
print('Downloading dataset CUB2011...')
12781318
filename = 'data/CUB_200_2011.tgz'
1279-
download(filename=filename, url=url, to_cache=False)
1319+
download(filename=filename, url=_resolve_download_urls(url), to_cache=False)
12801320
try:
12811321
tar = tarfile.open(filename, "r")
12821322
except tarfile.ReadError as err:

0 commit comments

Comments
 (0)