Skip to content

Commit 4b0e6a5

Browse files
authored
Merge pull request #1989 from marxide/clear-session-header-range
Delete range from session request header after use
2 parents 809fdfb + 9bd096b commit 4b0e6a5

File tree

2 files changed

+39
-17
lines changed

2 files changed

+39
-17
lines changed

astroquery/query.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def _request(self, method, url,
210210
files=None, save=False, savedir='', timeout=None, cache=True,
211211
stream=False, auth=None, continuation=True, verify=True,
212212
allow_redirects=True,
213-
json=None):
213+
json=None, return_response_on_save=False):
214214
"""
215215
A generic HTTP request method, similar to `requests.Session.request`
216216
but with added caching-related tools
@@ -249,14 +249,21 @@ def _request(self, method, url,
249249
parameter will try to continue the download where it left off.
250250
See `_download_file`.
251251
stream : bool
252+
return_response_on_save : bool
253+
If ``save``, also return the server response. The default is to only
254+
return the local file path.
252255
253256
Returns
254257
-------
255258
response : `requests.Response`
256259
The response from the server if ``save`` is False
257260
local_filepath : list
258261
a list of strings containing the downloaded local paths if ``save``
259-
is True
262+
is True and ``return_response_on_save`` is False.
263+
(local_filepath, response) : tuple(list, `requests.Response`)
264+
a tuple containing a list of strings containing the downloaded local paths,
265+
and the server response object, if ``save`` is True and ``return_response_on_save``
266+
is True.
260267
"""
261268
req_kwargs = dict(
262269
params=params,
@@ -274,11 +281,14 @@ def _request(self, method, url,
274281
local_filename = local_filename.replace(':', '_')
275282
local_filepath = os.path.join(savedir or self.cache_location or '.', local_filename)
276283

277-
self._download_file(url, local_filepath, cache=cache,
278-
continuation=continuation, method=method,
279-
allow_redirects=allow_redirects,
280-
auth=auth, **req_kwargs)
281-
return local_filepath
284+
response = self._download_file(url, local_filepath, cache=cache,
285+
continuation=continuation, method=method,
286+
allow_redirects=allow_redirects,
287+
auth=auth, **req_kwargs)
288+
if return_response_on_save:
289+
return local_filepath, response
290+
else:
291+
return local_filepath
282292
else:
283293
query = AstroQuery(method, url, **req_kwargs)
284294
if ((self.cache_location is None) or (not self._cache_active) or (not cache)):
@@ -369,6 +379,7 @@ def _download_file(self, url, local_filepath, timeout=None, auth=None,
369379
timeout=timeout, stream=True,
370380
auth=auth, **kwargs)
371381
response.raise_for_status()
382+
del self._session.headers['Range']
372383

373384
elif cache and os.path.exists(local_filepath):
374385
if length is not None:

astroquery/tests/test_resume.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,17 @@
99

1010

1111
@pytest.mark.skipif('not ACTIVE_HTTPBIN')
12-
def test_resume():
12+
def test_resume(length=2048, partial_length=1024, qu=None):
1313
# Test that a resumed query will finish
1414

15-
length = 2048
16-
1715
# 'range' will return abcd...xyz repeating
1816
target_url = 'http://127.0.0.1:5000/range/{0}'.format(length)
1917

2018
# simple check: make sure the server does what's expected
2119
assert len(requests.get(target_url).content) == length
2220

23-
qu = query.BaseQuery()
21+
if qu is None:
22+
qu = query.BaseQuery()
2423

2524
result_1 = qu._request('GET', target_url, save=True)
2625
# now the full file is written, so we have to delete parts of it
@@ -29,17 +28,29 @@ def test_resume():
2928
data = fh.read()
3029
with open(result_1, 'wb') as fh:
3130
# overwrite with a partial file
32-
fh.write(data[:1024])
31+
fh.write(data[:partial_length])
3332
with open(result_1, 'rb') as fh:
3433
data = fh.read()
35-
assert len(data) == 1024
34+
assert len(data) == partial_length
3635

37-
result_2 = qu._request('GET', target_url, save=True, continuation=True)
36+
result_2, response = qu._request('GET', target_url, save=True, continuation=True,
37+
return_response_on_save=True)
3838

39-
assert 'range' in qu._session.headers
40-
assert qu._session.headers['range'] == 'bytes={0}-{1}'.format(1024, length-1)
39+
assert 'content-range' in response.headers
40+
assert response.headers['content-range'] == 'bytes {0}-{1}/{2}'.format(
41+
partial_length, length-1, length
42+
)
4143

4244
with open(result_2, 'rb') as fh:
4345
data = fh.read()
4446
assert len(data) == length
45-
assert data == (string.ascii_lowercase*80)[:length].encode('ascii')
47+
assert data == (string.ascii_lowercase*(length//26+1))[:length].encode('ascii')
48+
49+
50+
@pytest.mark.skipif('not ACTIVE_HTTPBIN')
51+
def test_resume_consecutive():
52+
# Test that consecutive resumed queries request the correct content range and finish
53+
qu = query.BaseQuery()
54+
55+
test_resume(length=2048, partial_length=1024, qu=qu)
56+
test_resume(length=2048, partial_length=512, qu=qu)

0 commit comments

Comments
 (0)