Skip to content

Commit 3b642ff

Browse files
authored
Merge pull request #1634 from ceb8/cache_timeout
ENH: Cache refactoring
2 parents b07d5db + 016f6b4 commit 3b642ff

File tree

4 files changed

+356
-26
lines changed

4 files changed

+356
-26
lines changed

astroquery/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717

1818
from .logger import _init_log
19+
from astropy import config as _config
1920

2021
__all__ = ["__version__", "__githash__", "__citation__", "__bibtex__", "test", "log"]
2122

@@ -38,3 +39,23 @@ def _get_bibtex():
3839
logging.addLevelName(5, "TRACE")
3940
log = logging.getLogger()
4041
log = _init_log()
42+
43+
44+
# Set up cache configuration
45+
class Cache_Conf(_config.ConfigNamespace):
    """
    Configuration parameters for the astroquery-wide query cache.
    """

    # Cache lifetime in seconds (604800 s = 1 week); None disables expiry.
    cache_timeout = _config.ConfigItem(
        defaultvalue=604800,
        description=(
            'Astroquery-wide cache timeout (seconds). Default is 1 week (604800). '
            'Setting to None prevents the cache from expiring (not recommended).'
        ),
        cfgtype='integer',
    )

    # Global on/off switch for caching.
    cache_active = _config.ConfigItem(
        defaultvalue=True,
        description="Astroquery global cache usage, False turns off all caching.",
        cfgtype='boolean',
    )
59+
60+
61+
cache_conf = Cache_Conf()

astroquery/query.py

Lines changed: 66 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,31 @@
1111
import requests
1212
import textwrap
1313

14+
from datetime import datetime, timedelta
15+
from pathlib import Path
16+
1417
from astropy.config import paths
15-
from astroquery import log
1618
import astropy.units as u
1719
from astropy.utils.console import ProgressBarOrSpinner
1820
import astropy.utils.data
21+
from astropy.utils import deprecated
22+
23+
from astroquery import version, log, cache_conf
24+
from astroquery.utils import system_tools
1925

20-
from . import version
21-
from .utils import system_tools
2226

2327
__all__ = ['BaseQuery', 'QueryWithLogin']
2428

2529

2630
def to_cache(response, cache_file):
    """
    Pickle a copy of ``response`` into ``cache_file``.

    Parameters
    ----------
    response : object
        The response to store. It is deep-copied first, so the object passed
        in is not modified.
    cache_file : str or path-like
        Destination file; it is opened in binary write mode.
    """
    log.debug("Caching data to {0}".format(cache_file))

    snapshot = copy.deepcopy(response)
    if hasattr(snapshot, 'request'):
        # Remove every hook entry from the copied request before pickling.
        hooks = snapshot.request.hooks
        for hook_name in list(hooks):
            del hooks[hook_name]
    with open(cache_file, "wb") as stream:
        pickle.dump(snapshot, stream, protocol=4)
3439

3540

3641
def _replace_none_iterable(iterable):
@@ -102,20 +107,30 @@ def hash(self):
102107
return self._hash
103108

104109
def request_file(self, cache_location):
    """
    Return the path of the pickle file used to cache this query.

    Parameters
    ----------
    cache_location : str or `~pathlib.Path`
        Directory that holds the cache files.

    Returns
    -------
    fn : `~pathlib.Path`
        ``cache_location`` joined with ``<hash>.pickle``, where ``<hash>`` is
        the value returned by ``self.hash()``.
    """
    # Coerce to Path so that plain-string locations (which worked when this
    # was built with os.path.join) are still accepted.
    return Path(cache_location, self.hash() + ".pickle")
107112

108-
def from_cache(self, cache_location, cache_timeout):
    """
    Load this query's cached response, if present and not expired.

    Parameters
    ----------
    cache_location : `~pathlib.Path`
        Directory that holds the cache files.
    cache_timeout : int, float or None
        Maximum age of the cache file in seconds, measured from the file's
        modification time. ``None`` means the cache never expires.

    Returns
    -------
    response : `requests.Response` or None
        The unpickled response, or ``None`` when the cache file is missing,
        has expired, or does not hold a `requests.Response`.
    """
    # Local import: the module only imports ``datetime`` and ``timedelta``.
    from datetime import timezone

    request_file = self.request_file(cache_location)
    response = None
    try:
        if cache_timeout is None:
            expired = False
        else:
            # Timezone-aware arithmetic; datetime.utcnow() and
            # datetime.utcfromtimestamp() are deprecated since Python 3.12.
            current_time = datetime.now(timezone.utc)
            cache_time = datetime.fromtimestamp(request_file.stat().st_mtime,
                                                timezone.utc)
            expired = current_time - cache_time > timedelta(seconds=cache_timeout)

        if expired:
            log.debug(f"Cache expired for {request_file}...")
        else:
            # NOTE(review): pickle.load can execute arbitrary code; this
            # assumes the cache directory is only written by trusted code.
            with open(request_file, "rb") as f:
                cached = pickle.load(f)
            if isinstance(cached, requests.Response):
                response = cached
    except FileNotFoundError:
        # No cache file for this query (raised by stat() or open()).
        response = None
    if response:
        log.debug("Retrieved data from {0}".format(request_file))
    return response
120135

121136
def remove_cache_file(self, cache_location):
@@ -125,8 +140,8 @@ def remove_cache_file(self, cache_location):
125140
"""
126141
request_file = self.request_file(cache_location)

# Path.exists is a method and must be called: the bare attribute is always
# truthy, which made the FileNotFoundError branch below unreachable.
if request_file.exists():
    request_file.unlink()
else:
    raise FileNotFoundError(f"Tried to remove cache file {request_file} but "
                            "it does not exist")
@@ -173,11 +188,8 @@ def __init__(self):
173188
.format(vers=version.version,
174189
olduseragent=S.headers['User-Agent']))
175190

176-
self.cache_location = os.path.join(
177-
paths.get_cache_dir(), 'astroquery',
178-
self.__class__.__name__.split("Class")[0])
179-
os.makedirs(self.cache_location, exist_ok=True)
180-
self._cache_active = True
191+
self.name = self.__class__.__name__.split("Class")[0]
192+
self._cache_location = None
181193

182194
def __call__(self, *args, **kwargs):
183195
""" init a fresh copy of self """
@@ -217,9 +229,28 @@ def _response_hook(self, response, *args, **kwargs):
217229
f"-----------------------------------------", '\t')
218230
log.log(5, f"HTTP response\n{response_log}")
219231

232+
@property
def cache_location(self):
    """
    Directory in which cached responses are stored, as a `~pathlib.Path`.

    Uses the location assigned through the setter when there is one, and
    otherwise ``<astropy cache dir>/astroquery/<self.name>``. The directory
    is created, including missing parents, on every access.
    """
    location = self._cache_location
    if not location:
        location = Path(paths.get_cache_dir(), 'astroquery', self.name)
    location.mkdir(parents=True, exist_ok=True)
    return location

@cache_location.setter
def cache_location(self, loc):
    # Store the custom location as a Path, whatever path-like was given.
    self._cache_location = Path(loc)
241+
242+
def reset_cache_location(self):
    """
    Forget any custom cache location by setting ``_cache_location`` back to
    ``None``.
    """
    self._cache_location = None
245+
246+
def clear_cache(self):
    """Delete every ``*.pickle`` file found directly in ``self.cache_location``."""
    for cached_file in self.cache_location.glob("*.pickle"):
        cached_file.unlink()
250+
220251
def _request(self, method, url,
221252
params=None, data=None, headers=None,
222-
files=None, save=False, savedir='', timeout=None, cache=True,
253+
files=None, save=False, savedir='', timeout=None, cache=None,
223254
stream=False, auth=None, continuation=True, verify=True,
224255
allow_redirects=True,
225256
json=None, return_response_on_save=False):
@@ -253,6 +284,7 @@ def _request(self, method, url,
253284
somewhere other than `BaseQuery.cache_location`
254285
timeout : int
255286
cache : bool
287+
Optional, if specified, overrides global cache settings.
256288
verify : bool
257289
Verify the server's TLS certificate?
258290
(see http://docs.python-requests.org/en/master/_modules/requests/sessions/?highlight=verify)
@@ -278,12 +310,16 @@ def _request(self, method, url,
278310
is True.
279311
"""
280312

313+
if cache is None: # Global caching not overridden
314+
cache = cache_conf.cache_active
315+
281316
if save:
282317
local_filename = url.split('/')[-1]
283318
if os.name == 'nt':
284319
# Windows doesn't allow special characters in filenames like
285320
# ":" so replace them with an underscore
286321
local_filename = local_filename.replace(':', '_')
322+
287323
local_filepath = os.path.join(savedir or self.cache_location or '.', local_filename)
288324

289325
response = self._download_file(url, local_filepath, cache=cache, timeout=timeout,
@@ -298,14 +334,14 @@ def _request(self, method, url,
298334
else:
299335
query = AstroQuery(method, url, params=params, data=data, headers=headers,
300336
files=files, timeout=timeout, json=json)
301-
if ((self.cache_location is None) or (not self._cache_active) or (not cache)):
302-
with suspend_cache(self):
337+
if not cache:
338+
with cache_conf.set_temp("cache_active", False):
303339
response = query.request(self._session, stream=stream,
304340
auth=auth, verify=verify,
305341
allow_redirects=allow_redirects,
306342
json=json)
307343
else:
308-
response = query.from_cache(self.cache_location)
344+
response = query.from_cache(self.cache_location, cache_conf.cache_timeout)
309345
if not response:
310346
response = query.request(self._session,
311347
self.cache_location,
@@ -315,6 +351,7 @@ def _request(self, method, url,
315351
verify=verify,
316352
json=json)
317353
to_cache(response, query.request_file(self.cache_location))
354+
318355
self._last_query = query
319356
return response
320357

@@ -336,6 +373,7 @@ def _download_file(self, url, local_filepath, timeout=None, auth=None,
336373
supports HTTP "range" requests, the download will be continued
337374
where it left off.
338375
cache : bool
376+
Cache downloaded file. Defaults to False.
339377
method : "GET" or "POST"
340378
head_safe : bool
341379
"""
@@ -439,19 +477,21 @@ def _download_file(self, url, local_filepath, timeout=None, auth=None,
439477
return response
440478

441479

480+
@deprecated(since="v0.4.7", message=("The suspend_cache function is deprecated, "
                                     "use the conf set_temp function instead."))
class suspend_cache:
    """
    A context manager that suspends caching.

    While the ``with`` block is active ``cache_conf.cache_active`` is set to
    ``False``; on exit it is set back to the value it had when the block was
    entered.

    Parameters
    ----------
    obj : object, optional
        Unused. Accepted so that ``suspend_cache(query_object)`` calls keep
        working.
    """

    def __init__(self, obj=None):
        # Set here as well so the attribute exists before the block is entered.
        self.original_cache_setting = cache_conf.cache_active

    def __enter__(self):
        # Capture the setting at entry rather than at construction, so a
        # manager created earlier does not restore a stale value.
        self.original_cache_setting = cache_conf.cache_active
        cache_conf.cache_active = False

    def __exit__(self, exc_type, exc_value, traceback):
        cache_conf.cache_active = self.original_cache_setting
        # Returning False lets any exception from the block propagate.
        return False
456496

457497

@@ -507,7 +547,7 @@ def _login(self, *args, **kwargs):
507547
pass
508548

509549
def login(self, *args, **kwargs):
    """
    Log in by delegating to ``self._login`` with caching switched off.

    All positional and keyword arguments are forwarded to ``self._login``.
    Its return value is stored in ``self._authenticated`` and returned.
    """
    # Caching is disabled for the duration of the login call only.
    with cache_conf.set_temp("cache_active", False):
        outcome = self._login(*args, **kwargs)
        self._authenticated = outcome
    return self._authenticated
513553

0 commit comments

Comments
 (0)