Skip to content

Commit 6df612a

Browse files
committed
integrating better with astropy config framework
1 parent 2f5fef9 commit 6df612a

File tree

3 files changed

+95
-80
lines changed

3 files changed

+95
-80
lines changed

astroquery/__init__.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,17 @@ def _get_bibtex():
5454
# Set up cache configuration
5555
class Conf(_config.ConfigNamespace):
5656

57-
default_cache_timeout = _config.ConfigItem(
58-
604800, # 1 week
59-
'Astroquery-wide default cache timeout (seconds).'
60-
)
61-
default_cache_location = _config.ConfigItem(
62-
os.path.join(_config.paths.get_cache_dir(), 'astroquery'),
63-
'Astroquery default cache location (within astropy cache).'
64-
)
65-
default_cache_active = _config.ConfigItem(
57+
cache_timeout = _config.ConfigItem(
58+
604800, # 1 week
59+
'Astroquery-wide cache timeout (seconds).',
60+
cfgtype='integer'
61+
)
62+
63+
cache_active = _config.ConfigItem(
6664
True,
67-
"Astroquery global cache usage, False turns off all caching."
68-
)
65+
"Astroquery global cache usage, False turns off all caching.",
66+
cfgtype='boolean'
67+
)
6968

7069

7170
conf = Conf()

astroquery/query.py

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,6 @@ def to_cache(response, cache_file):
3333
if hasattr(response, 'request'):
3434
for key in tuple(response.request.hooks.keys()):
3535
del response.request.hooks[key]
36-
37-
chache_dir, _ = os.path.split(cache_file)
38-
Path(chache_dir).mkdir(parents=True, exist_ok=True)
3936
with open(cache_file, "wb") as f:
4037
pickle.dump(response, f)
4138

@@ -191,7 +188,7 @@ def __init__(self):
191188
olduseragent=S.headers['User-Agent']))
192189

193190
self.name = self.__class__.__name__.split("Class")[0]
194-
self.reset_cache_preferences()
191+
self.cache_location = None
195192

196193
def __call__(self, *args, **kwargs):
197194
""" init a fresh copy of self """
@@ -231,23 +228,24 @@ def _response_hook(self, response, *args, **kwargs):
231228
f"-----------------------------------------", '\t')
232229
log.log(5, f"HTTP response\n{response_log}")
233230

234-
def clear_cache(self):
235-
"""Removes all cache files."""
231+
@property
232+
def _cache_location(self):
233+
cl = self.cache_location or os.path.join(paths.get_cache_dir(), 'astroquery', self.name)
234+
Path(cl).mkdir(parents=True, exist_ok=True)
235+
return cl
236236

237-
cache_files = [x for x in os.listdir(self.cache_location) if x.endswith("pickle")]
238-
for fle in cache_files:
239-
os.remove(os.path.join(self.cache_location, fle))
237+
def get_cache_location(self):
238+
return self._cache_location
240239

241-
def reset_cache_preferences(self):
242-
"""Resets cache preferences to default values"""
240+
def reset_cache_location(self):
241+
self.cache_location = None
243242

244-
self.cache_location = os.path.join(
245-
conf.default_cache_location,
246-
self.__class__.__name__.split("Class")[0])
247-
os.makedirs(self.cache_location, exist_ok=True)
243+
def clear_cache(self):
244+
"""Removes all cache files."""
248245

249-
self.cache_active = conf.default_cache_active
250-
self.cache_timeout = conf.default_cache_timeout
246+
cache_files = [x for x in os.listdir(self._cache_location) if x.endswith("pickle")]
247+
for fle in cache_files:
248+
os.remove(os.path.join(self._cache_location, fle))
251249

252250
def _request(self, method, url,
253251
params=None, data=None, headers=None,
@@ -319,7 +317,7 @@ def _request(self, method, url,
319317
json=json
320318
)
321319

322-
if (cache is not False) and self.cache_active:
320+
if (cache is not False) and conf.cache_active:
323321
cache = True
324322
else:
325323
cache = False
@@ -331,7 +329,7 @@ def _request(self, method, url,
331329
# ":" so replace them with an underscore
332330
local_filename = local_filename.replace(':', '_')
333331

334-
local_filepath = os.path.join(savedir or self.cache_location or '.', local_filename)
332+
local_filepath = os.path.join(savedir or self._cache_location or '.', local_filename)
335333

336334
response = self._download_file(url, local_filepath, cache=cache,
337335
continuation=continuation, method=method,
@@ -343,23 +341,23 @@ def _request(self, method, url,
343341
return local_filepath
344342
else:
345343
query = AstroQuery(method, url, **req_kwargs)
346-
if ((self.cache_location is None) or (not self.cache_active) or (not cache)):
347-
with suspend_cache(self):
344+
if ((self._cache_location is None) or (not cache)):
345+
with conf.set_temp("cache_active", False):
348346
response = query.request(self._session, stream=stream,
349347
auth=auth, verify=verify,
350348
allow_redirects=allow_redirects,
351349
json=json)
352350
else:
353-
response = query.from_cache(self.cache_location, self.cache_timeout)
351+
response = query.from_cache(self._cache_location, conf.cache_timeout)
354352
if not response:
355353
response = query.request(self._session,
356-
self.cache_location,
354+
self._cache_location,
357355
stream=stream,
358356
auth=auth,
359357
allow_redirects=allow_redirects,
360358
verify=verify,
361359
json=json)
362-
to_cache(response, query.request_file(self.cache_location))
360+
to_cache(response, query.request_file(self._cache_location))
363361

364362
self._last_query = query
365363
return response
@@ -486,23 +484,6 @@ def _download_file(self, url, local_filepath, timeout=None, auth=None,
486484
return response
487485

488486

489-
class suspend_cache:
490-
"""
491-
A context manager that suspends caching.
492-
"""
493-
494-
def __init__(self, obj):
495-
self.obj = obj
496-
self.original_cache_setting = self.obj.cache_active
497-
498-
def __enter__(self):
499-
self.obj.cache_active = False
500-
501-
def __exit__(self, exc_type, exc_value, traceback):
502-
self.obj.cache_active = self.original_cache_setting
503-
return False
504-
505-
506487
class QueryWithLogin(BaseQuery):
507488
"""
508489
This is the base class for all the query classes which are required to
@@ -555,7 +536,7 @@ def _login(self, *args, **kwargs):
555536
pass
556537

557538
def login(self, *args, **kwargs):
558-
with suspend_cache(self):
539+
with conf.set_temp("cache_active", False):
559540
self._authenticated = self._login(*args, **kwargs)
560541
return self._authenticated
561542

astroquery/tests/test_cache.py

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
import os
44

55
from time import sleep
6+
from pathlib import Path
7+
8+
from astropy.config import paths
69

710
from astroquery.query import QueryWithLogin
11+
from astroquery import conf
812

913
URL1 = "http://fakeurl.edu"
1014
URL2 = "http://fakeurl.ac.uk"
@@ -46,68 +50,99 @@ def _login(self, username):
4650
return False
4751

4852

49-
def test_cache_reset():
50-
mytest = TestClass()
51-
assert mytest.cache_active
53+
def test_conf():
54+
default_timeout = conf.cache_timeout
55+
default_active = conf.cache_active
56+
57+
assert default_timeout == 604800
58+
assert default_active is True
5259

53-
default_timeout = mytest.cache_timeout
54-
default_loc = mytest.cache_location
60+
with conf.set_temp("cache_timeout", 5):
61+
assert conf.cache_timeout == 5
5562

56-
mytest.cache_timeout = 5
57-
mytest.cache_location = "new/location"
63+
with conf.set_temp("cache_active", False):
64+
assert conf.cache_active is False
5865

59-
mytest.reset_cache_preferences()
66+
assert conf.cache_timeout == default_timeout
67+
assert conf.cache_active == default_active
6068

61-
assert mytest.cache_timeout == default_timeout
62-
assert mytest.cache_location == default_loc
69+
conf.cache_timeout = 5
70+
conf.cache_active = False
71+
conf.reset()
72+
73+
assert conf.cache_timeout == default_timeout
74+
assert conf.cache_active == default_active
6375

6476

6577
def test_basic_caching():
6678

6779
mytest = TestClass()
68-
assert mytest.cache_active
80+
assert conf.cache_active
6981

7082
mytest.clear_cache()
71-
assert len(os.listdir(mytest.cache_location)) == 0
83+
assert len(os.listdir(mytest.get_cache_location())) == 0
7284

7385
set_response(TEXT1)
7486

7587
resp = mytest.test_func(URL1)
7688
assert resp.content == TEXT1
77-
assert len(os.listdir(mytest.cache_location)) == 1
89+
assert len(os.listdir(mytest.get_cache_location())) == 1
7890

7991
set_response(TEXT2)
8092

8193
resp = mytest.test_func(URL2) # query that has not been cached
8294
assert resp.content == TEXT2
83-
assert len(os.listdir(mytest.cache_location)) == 2
95+
assert len(os.listdir(mytest.get_cache_location())) == 2
8496

8597
resp = mytest.test_func(URL1)
8698
assert resp.content == TEXT1 # query that was cached
87-
assert len(os.listdir(mytest.cache_location)) == 2 # no new cache file
99+
assert len(os.listdir(mytest.get_cache_location())) == 2 # no new cache file
88100

89101
mytest.clear_cache()
90-
assert len(os.listdir(mytest.cache_location)) == 0
102+
assert len(os.listdir(mytest.get_cache_location())) == 0
91103

92104
resp = mytest.test_func(URL1)
93105
assert resp.content == TEXT2 # Now get new response
94106

95107

108+
def test_change_location(tmpdir):
109+
110+
mytest = TestClass()
111+
default_cache_location = mytest.get_cache_location()
112+
113+
assert paths.get_cache_dir() in default_cache_location
114+
assert "astroquery" in mytest.get_cache_location()
115+
assert mytest.name in mytest.get_cache_location()
116+
117+
new_loc = os.path.join(tmpdir, "new_dir")
118+
mytest.cache_location = new_loc
119+
assert mytest.get_cache_location() == new_loc
120+
121+
mytest.reset_cache_location()
122+
assert mytest.get_cache_location() == default_cache_location
123+
124+
Path(new_loc).mkdir(parents=True, exist_ok=True)
125+
with paths.set_temp_cache(new_loc):
126+
assert new_loc in mytest.get_cache_location()
127+
assert "astroquery" in mytest.get_cache_location()
128+
assert mytest.name in mytest.get_cache_location()
129+
130+
96131
def test_login():
97132

98133
mytest = TestClass()
99-
assert mytest.cache_active
134+
assert conf.cache_active
100135

101136
mytest.clear_cache()
102-
assert len(os.listdir(mytest.cache_location)) == 0
137+
assert len(os.listdir(mytest.get_cache_location())) == 0
103138

104139
set_response(TEXT1) # Text 1 is set as the approved password
105140

106141
mytest.login("ceb")
107142
assert mytest.authenticated()
108-
assert len(os.listdir(mytest.cache_location)) == 0 # request should not be cached
143+
assert len(os.listdir(mytest.get_cache_location())) == 0 # request should not be cached
109144

110-
set_response(TEXT2) # Text 1 is not the approved password
145+
set_response(TEXT2) # Text 2 is not the approved password
111146

112147
mytest.login("ceb")
113148
assert not mytest.authenticated() # Should not be accessing cache
@@ -116,12 +151,12 @@ def test_login():
116151
def test_timeout():
117152

118153
mytest = TestClass()
119-
assert mytest.cache_active
154+
assert conf.cache_active
120155

121156
mytest.clear_cache()
122-
assert len(os.listdir(mytest.cache_location)) == 0
157+
assert len(os.listdir(mytest.get_cache_location())) == 0
123158

124-
mytest.cache_timeout = 2 # Set to 2 sec so we can reach timeout easily
159+
conf.cache_timeout = 2 # Set to 2 sec so we can reach timeout easily
125160

126161
set_response(TEXT1) # setting the response
127162

@@ -141,19 +176,19 @@ def test_timeout():
141176
def test_deactivate():
142177

143178
mytest = TestClass()
144-
mytest.cache_active = False
179+
conf.cache_active = False
145180

146181
mytest.clear_cache()
147-
assert len(os.listdir(mytest.cache_location)) == 0
182+
assert len(os.listdir(mytest.get_cache_location())) == 0
148183

149184
set_response(TEXT1)
150185

151186
resp = mytest.test_func(URL1)
152187
assert resp.content == TEXT1
153-
assert len(os.listdir(mytest.cache_location)) == 0
188+
assert len(os.listdir(mytest.get_cache_location())) == 0
154189

155190
set_response(TEXT2)
156191

157192
resp = mytest.test_func(URL1)
158193
assert resp.content == TEXT2
159-
assert len(os.listdir(mytest.cache_location)) == 0
194+
assert len(os.listdir(mytest.get_cache_location())) == 0

0 commit comments

Comments
 (0)