11
11
import requests
12
12
import textwrap
13
13
14
+ from datetime import datetime , timedelta
15
+ from pathlib import Path
16
+
14
17
from astropy .config import paths
15
- from astroquery import log
16
18
import astropy .units as u
17
19
from astropy .utils .console import ProgressBarOrSpinner
18
20
import astropy .utils .data
21
+ from astropy .utils import deprecated
22
+
23
+ from astroquery import version , log , cache_conf
24
+ from astroquery .utils import system_tools
19
25
20
- from . import version
21
- from .utils import system_tools
22
26
23
27
__all__ = ['BaseQuery' , 'QueryWithLogin' ]
24
28
25
29
26
30
def to_cache (response , cache_file ):
27
31
log .debug ("Caching data to {0}" .format (cache_file ))
32
+
28
33
response = copy .deepcopy (response )
29
34
if hasattr (response , 'request' ):
30
35
for key in tuple (response .request .hooks .keys ()):
31
36
del response .request .hooks [key ]
32
37
with open (cache_file , "wb" ) as f :
33
- pickle .dump (response , f )
38
+ pickle .dump (response , f , protocol = 4 )
34
39
35
40
36
41
def _replace_none_iterable (iterable ):
@@ -102,20 +107,30 @@ def hash(self):
102
107
return self ._hash
103
108
104
109
def request_file (self , cache_location ):
105
- fn = os . path . join ( cache_location , self .hash () + ".pickle" )
110
+ fn = cache_location . joinpath ( self .hash () + ".pickle" )
106
111
return fn
107
112
108
- def from_cache (self , cache_location ):
113
+ def from_cache (self , cache_location , cache_timeout ):
109
114
request_file = self .request_file (cache_location )
110
115
try :
111
- with open (request_file , "rb" ) as f :
112
- response = pickle .load (f )
113
- if not isinstance (response , requests .Response ):
116
+ if cache_timeout is None :
117
+ expired = False
118
+ else :
119
+ current_time = datetime .utcnow ()
120
+ cache_time = datetime .utcfromtimestamp (request_file .stat ().st_mtime )
121
+ expired = current_time - cache_time > timedelta (seconds = cache_timeout )
122
+ if not expired :
123
+ with open (request_file , "rb" ) as f :
124
+ response = pickle .load (f )
125
+ if not isinstance (response , requests .Response ):
126
+ response = None
127
+ else :
128
+ log .debug (f"Cache expired for { request_file } ..." )
114
129
response = None
115
130
except FileNotFoundError :
116
131
response = None
117
132
if response :
118
- log .debug ("Retrieving data from {0}" .format (request_file ))
133
+ log .debug ("Retrieved data from {0}" .format (request_file ))
119
134
return response
120
135
121
136
def remove_cache_file (self , cache_location ):
@@ -125,8 +140,8 @@ def remove_cache_file(self, cache_location):
125
140
"""
126
141
request_file = self .request_file (cache_location )
127
142
128
- if os . path . exists ( request_file ) :
129
- os . remove ( request_file )
143
+ if request_file . exists :
144
+ request_file . unlink ( )
130
145
else :
131
146
raise FileNotFoundError (f"Tried to remove cache file { request_file } but "
132
147
"it does not exist" )
@@ -173,11 +188,8 @@ def __init__(self):
173
188
.format (vers = version .version ,
174
189
olduseragent = S .headers ['User-Agent' ]))
175
190
176
- self .cache_location = os .path .join (
177
- paths .get_cache_dir (), 'astroquery' ,
178
- self .__class__ .__name__ .split ("Class" )[0 ])
179
- os .makedirs (self .cache_location , exist_ok = True )
180
- self ._cache_active = True
191
+ self .name = self .__class__ .__name__ .split ("Class" )[0 ]
192
+ self ._cache_location = None
181
193
182
194
def __call__ (self , * args , ** kwargs ):
183
195
""" init a fresh copy of self """
@@ -217,9 +229,28 @@ def _response_hook(self, response, *args, **kwargs):
217
229
f"-----------------------------------------" , '\t ' )
218
230
log .log (5 , f"HTTP response\n { response_log } " )
219
231
232
+ @property
233
+ def cache_location (self ):
234
+ cl = self ._cache_location or Path (paths .get_cache_dir (), 'astroquery' , self .name )
235
+ cl .mkdir (parents = True , exist_ok = True )
236
+ return cl
237
+
238
+ @cache_location .setter
239
+ def cache_location (self , loc ):
240
+ self ._cache_location = Path (loc )
241
+
242
+ def reset_cache_location (self ):
243
+ """Resets the cache location to the default astropy cache"""
244
+ self ._cache_location = None
245
+
246
+ def clear_cache (self ):
247
+ """Removes all cache files."""
248
+ for fle in self .cache_location .glob ("*.pickle" ):
249
+ fle .unlink ()
250
+
220
251
def _request (self , method , url ,
221
252
params = None , data = None , headers = None ,
222
- files = None , save = False , savedir = '' , timeout = None , cache = True ,
253
+ files = None , save = False , savedir = '' , timeout = None , cache = None ,
223
254
stream = False , auth = None , continuation = True , verify = True ,
224
255
allow_redirects = True ,
225
256
json = None , return_response_on_save = False ):
@@ -253,6 +284,7 @@ def _request(self, method, url,
253
284
somewhere other than `BaseQuery.cache_location`
254
285
timeout : int
255
286
cache : bool
287
+ Optional, if specified, overrides global cache settings.
256
288
verify : bool
257
289
Verify the server's TLS certificate?
258
290
(see http://docs.python-requests.org/en/master/_modules/requests/sessions/?highlight=verify)
@@ -278,12 +310,16 @@ def _request(self, method, url,
278
310
is True.
279
311
"""
280
312
313
+ if cache is None : # Global caching not overridden
314
+ cache = cache_conf .cache_active
315
+
281
316
if save :
282
317
local_filename = url .split ('/' )[- 1 ]
283
318
if os .name == 'nt' :
284
319
# Windows doesn't allow special characters in filenames like
285
320
# ":" so replace them with an underscore
286
321
local_filename = local_filename .replace (':' , '_' )
322
+
287
323
local_filepath = os .path .join (savedir or self .cache_location or '.' , local_filename )
288
324
289
325
response = self ._download_file (url , local_filepath , cache = cache , timeout = timeout ,
@@ -298,14 +334,14 @@ def _request(self, method, url,
298
334
else :
299
335
query = AstroQuery (method , url , params = params , data = data , headers = headers ,
300
336
files = files , timeout = timeout , json = json )
301
- if (( self . cache_location is None ) or ( not self . _cache_active ) or ( not cache )) :
302
- with suspend_cache ( self ):
337
+ if not cache :
338
+ with cache_conf . set_temp ( "cache_active" , False ):
303
339
response = query .request (self ._session , stream = stream ,
304
340
auth = auth , verify = verify ,
305
341
allow_redirects = allow_redirects ,
306
342
json = json )
307
343
else :
308
- response = query .from_cache (self .cache_location )
344
+ response = query .from_cache (self .cache_location , cache_conf . cache_timeout )
309
345
if not response :
310
346
response = query .request (self ._session ,
311
347
self .cache_location ,
@@ -315,6 +351,7 @@ def _request(self, method, url,
315
351
verify = verify ,
316
352
json = json )
317
353
to_cache (response , query .request_file (self .cache_location ))
354
+
318
355
self ._last_query = query
319
356
return response
320
357
@@ -336,6 +373,7 @@ def _download_file(self, url, local_filepath, timeout=None, auth=None,
336
373
supports HTTP "range" requests, the download will be continued
337
374
where it left off.
338
375
cache : bool
376
+ Cache downloaded file. Defaults to False.
339
377
method : "GET" or "POST"
340
378
head_safe : bool
341
379
"""
@@ -439,19 +477,21 @@ def _download_file(self, url, local_filepath, timeout=None, auth=None,
439
477
return response
440
478
441
479
480
+ @deprecated (since = "v0.4.7" , message = ("The suspend_cache function is deprecated,"
481
+ "Use the conf set_temp function instead." ))
442
482
class suspend_cache :
443
483
"""
444
484
A context manager that suspends caching.
445
485
"""
446
486
447
- def __init__ (self , obj ):
448
- self .obj = obj
487
+ def __init__ (self , obj = None ):
488
+ self .original_cache_setting = cache_conf . cache_active
449
489
450
490
def __enter__ (self ):
451
- self . obj . _cache_active = False
491
+ cache_conf . cache_active = False
452
492
453
493
def __exit__ (self , exc_type , exc_value , traceback ):
454
- self . obj . _cache_active = True
494
+ cache_conf . cache_active = self . original_cache_setting
455
495
return False
456
496
457
497
@@ -507,7 +547,7 @@ def _login(self, *args, **kwargs):
507
547
pass
508
548
509
549
def login (self , * args , ** kwargs ):
510
- with suspend_cache ( self ):
550
+ with cache_conf . set_temp ( "cache_active" , False ):
511
551
self ._authenticated = self ._login (* args , ** kwargs )
512
552
return self ._authenticated
513
553
0 commit comments