diff --git a/download_ffhq.py b/download_ffhq.py
index b626ebce7..b6d300d7c 100755
--- a/download_ffhq.py
+++ b/download_ffhq.py
@@ -27,6 +27,7 @@
 import itertools
 import shutil
 from collections import OrderedDict, defaultdict
+import os.path as osp
 
 PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True # avoid "Decompressed Data Too Large" error
 
@@ -54,22 +55,31 @@
     'tfrecords': dict(file_url='https://drive.google.com/uc?id=1SYUmqKdLoTYq-kqsnPsniLScMhspvl5v', file_path='tfrecords/ffhq/LICENSE.txt', file_size=1610, file_md5='724f3831aaecd61a84fe98500079abc2'),
 }
+
+headers = {
+    "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36" # NOQA
+    }
+
+home = osp.expanduser("~")
+
 #----------------------------------------------------------------------------
 
 def download_file(session, file_spec, stats, chunk_size=128, num_attempts=10, **kwargs):
     file_path = file_spec['file_path']
     file_url = file_spec['file_url']
+    file_url = file_url.split('?')[0] + '?export=download&confirm=no_antivirus&' + file_url.split('?')[1]
     file_dir = os.path.dirname(file_path)
     tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
     if file_dir:
         os.makedirs(file_dir, exist_ok=True)
+
     for attempts_left in reversed(range(num_attempts)):
         data_size = 0
         try:
             # Download.
             data_md5 = hashlib.md5()
-            with session.get(file_url, stream=True) as res:
+            with session.get(file_url, headers=headers, stream=True) as res:
                 res.raise_for_status()
                 with open(tmp_path, 'wb') as f:
                     for chunk in res.iter_content(chunk_size=chunk_size<<10):
@@ -118,6 +128,17 @@ def download_file(session, file_spec, stats, chunk_size=128, num_attempts=10, **
             if not attempts_left:
                 raise
 
+    cache_dir = osp.join(home, ".cache", "ffhq_download")
+    cookies_file = osp.join(cache_dir, "cookies.json")
+
+    with open(cookies_file, "w") as f:
+        cookies = [
+            (k, v)
+            for k, v in session.cookies.items()
+            if not k.startswith("download_warning_")
+        ]
+        json.dump(cookies, f, indent=2)
+
     # Rename temp file to the correct name.
     os.replace(tmp_path, file_path) # atomic
     with stats['lock']:
@@ -164,6 +185,10 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5
         print('All files already downloaded -- skipping.')
         return
 
+    cache_dir = osp.join(home, ".cache", "ffhq_download")
+    if not osp.exists(cache_dir):
+        os.makedirs(cache_dir)
+
     # Launch worker threads.
     spec_queue = queue.Queue()
     exception_queue = queue.Queue()
@@ -208,6 +233,20 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5
 
 def _download_thread(spec_queue, exception_queue, stats, download_kwargs):
     with requests.Session() as session:
+
+        # Load cookies
+        cache_dir = osp.join(home, ".cache", "ffhq_download")
+        cookies_file = osp.join(cache_dir, "cookies.json")
+        if not osp.exists(cookies_file):
+            session.get('https://www.google.com')
+            cookies = [(k, v) for k, v in session.cookies.items() if not k.startswith("download_warning_")]
+            with open(cookies_file, "w") as f:
+                json.dump(cookies, f, indent=2)
+        with open(cookies_file) as f:
+            cookies = json.load(f)
+        for k, v in cookies:
+            session.cookies[k] = v
+
         while not spec_queue.empty():
             spec = spec_queue.get()
             try: