Skip to content

Commit cbfe81c

Browse files
authored
Merge pull request #9 from runtingt/trunting/return-paths
Return file paths for downloaded files
2 parents f96425e + 89ddc2e commit cbfe81c

File tree

4 files changed

+131
-17
lines changed

4 files changed

+131
-17
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,11 @@ or equivalently
155155

156156
```python
157157
id_list = client.find('reactions:"P P--> LQ LQ"', ids='inspire')
158-
client.download(id_list, ids='inspire', file_format='csv')
158+
downloads = client.download(id_list, ids='inspire', file_format='csv')
159+
print(downloads) # {'1222326': ['./hepdata-downloads/HEPData-ins1222326-v1-csv/Table1.csv', ...], ...}
159160
```
160161

161-
downloads four .tar.gz archives containing csv files and unpacks them in the default ```./hepdata-downloads``` directory.
162+
downloads four .tar.gz archives containing csv files and unpacks them in the default ```./hepdata-downloads``` directory. Using the API, a dictionary mapping ids to the downloaded files is returned.
162163

163164
### Example 5 - find table names in records:
164165

hepdata_cli/api.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,19 @@ def download(self, id_list, file_format=None, ids=None, table_name='', download_
8787
:param ids: accepts one of ('inspire', 'hepdata'). It specifies what type of ids have been passed.
8888
:param table_name: restricts download to specific tables.
8989
:param download_dir: defaults to ./hepdata-downloads. Specifies where to download the files.
90+
91+
:return: dictionary mapping id to list of downloaded files.
92+
:rtype: dict[int, list[str]]
9093
"""
9194

92-
urls = self._build_urls(id_list, file_format, ids, table_name)
93-
for url in urls:
95+
url_map = self._build_urls(id_list, file_format, ids, table_name)
96+
file_map = {}
97+
for record_id, url in url_map.items():
9498
if self.verbose is True:
9599
print("Downloading: " + url)
96-
download_url(url, download_dir)
100+
files_downloaded = download_url(url, download_dir)
101+
file_map[record_id] = files_downloaded
102+
return file_map
97103

98104
def fetch_names(self, id_list, ids=None):
99105
"""
@@ -102,9 +108,9 @@ def fetch_names(self, id_list, ids=None):
102108
:param id_list: list of id of records of which to return table names.
103109
:param ids: accepts one of ('inspire', 'hepdata'). It specifies what type of ids have been passed.
104110
"""
105-
urls = self._build_urls(id_list, 'json', ids, '')
111+
url_map = self._build_urls(id_list, 'json', ids, '')
106112
table_names = []
107-
for url in urls:
113+
for url in url_map.values():
108114
response = resilient_requests('get', url)
109115
json_dict = response.json()
110116
table_names += [[data_table['name'] for data_table in json_dict['data_tables']]]
@@ -136,7 +142,16 @@ def upload(self, path_to_file, email, recid=None, invitation_cookie=None, sandbo
136142
print('Uploaded ' + path_to_file + ' to ' + SITE_URL + '/record/' + str(recid))
137143

138144
def _build_urls(self, id_list, file_format, ids, table_name):
139-
"""Builds urls for download and fetch_names, given the specified parameters."""
145+
"""
146+
Builds urls for download and fetch_names, given the specified parameters.
147+
148+
:param id_list: list of ids to download.
149+
:param file_format: accepts one of ('csv', 'root', 'yaml', 'yoda', 'yoda1', 'yoda.h5', 'json').
150+
:param ids: accepts one of ('inspire', 'hepdata').
151+
:param table_name: restricts download to specific tables.
152+
153+
:return: dictionary mapping id to url.
154+
"""
140155
if type(id_list) not in (tuple, list):
141156
id_list = id_list.split()
142157
assert len(id_list) > 0, 'Ids are required.'
@@ -146,9 +161,12 @@ def _build_urls(self, id_list, file_format, ids, table_name):
146161
params = {'format': file_format}
147162
else:
148163
params = {'format': file_format, 'table': table_name}
149-
urls = [resilient_requests('get', SITE_URL + '/record/' + ('ins' if ids == 'inspire' else '') + id_entry, params=params).url.replace('%2525', '%25') for id_entry in id_list]
164+
url_mapping = {}
165+
for id_entry in id_list:
166+
url = resilient_requests('get', SITE_URL + '/record/' + ('ins' if ids == 'inspire' else '') + id_entry, params=params).url.replace('%2525', '%25')
167+
url_mapping[id_entry] = url
150168
# TODO: Investigate root cause of double URL encoding (https://github.com/HEPData/hepdata-cli/issues/8).
151-
return urls
169+
return url_mapping
152170

153171
def _query(self, query, page, size):
154172
"""Builds the search query passed to hepdata.net."""
@@ -170,6 +188,7 @@ def mkdir(directory):
170188

171189
def download_url(url, download_dir):
172190
"""Download file and if necessary extract it."""
191+
files_downloaded = []
173192
assert is_downloadable(url), "Given url is not downloadable: {}".format(url)
174193
response = resilient_requests('get', url, allow_redirects=True)
175194
if url[-4:] == 'json':
@@ -182,10 +201,31 @@ def download_url(url, download_dir):
182201
mkdir(os.path.dirname(filepath))
183202
open(filepath, 'wb').write(response.content)
184203
if filepath.endswith("tar.gz") or filepath.endswith("tar"):
185-
tar = tarfile.open(filepath, "r:gz" if filepath.endswith("tar.gz") else "r:")
186-
tar.extractall(path=os.path.dirname(filepath))
187-
tar.close()
188-
os.remove(filepath)
204+
tar = None
205+
try:
206+
tar = tarfile.open(filepath, "r:gz" if filepath.endswith("tar.gz") else "r:")
207+
extract_dir = os.path.abspath(os.path.dirname(filepath))
208+
tar.extractall(path=os.path.dirname(filepath))
209+
for member in tar.getmembers():
210+
if member.isfile():
211+
extracted_path = os.path.join(os.path.dirname(filepath), member.name)
212+
abs_extracted_path = os.path.abspath(extracted_path)
213+
if abs_extracted_path.startswith(extract_dir + os.sep) and os.path.exists(abs_extracted_path):
214+
files_downloaded.append(abs_extracted_path)
215+
elif not abs_extracted_path.startswith(extract_dir + os.sep):
216+
raise ValueError(f"Attempted path traversal for file {member.name}")
217+
else:
218+
raise FileNotFoundError(f"Extracted file {member.name} not found")
219+
except Exception as e:
220+
raise Exception(f"Failed to extract {filepath}: {str(e)}")
221+
finally:
222+
if tar:
223+
tar.close()
224+
if os.path.exists(filepath):
225+
os.remove(filepath)
226+
else:
227+
files_downloaded.append(filepath)
228+
return files_downloaded
189229

190230

191231
def getFilename_fromCd(cd):

hepdata_cli/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.2.3"
1+
__version__ = "0.3.0"

tests/test_download.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
import pytest
44
import os
55
import shutil
6+
import tarfile
7+
import tempfile
8+
from unittest.mock import patch
69

710
from click.testing import CliRunner
811

9-
from hepdata_cli.api import Client, mkdir
12+
from hepdata_cli.api import Client, download_url, mkdir
1013
from hepdata_cli.cli import cli
1114

1215

@@ -57,8 +60,10 @@ def test_api_download(id_list, file_format, ids, table):
5760
mkdir(test_download_dir)
5861
assert len(os.listdir(test_download_dir)) == 0
5962
client = Client(verbose=True)
60-
client.download(id_list, file_format, ids, table, test_download_dir)
63+
path_map = client.download(id_list, file_format, ids, table, test_download_dir)
64+
file_paths = [fp for fps in path_map.values() for fp in fps]
6165
assert len(os.listdir(test_download_dir)) > 0
66+
assert all(os.path.exists(fp) for fp in file_paths)
6267
cleanup(test_download_dir)
6368

6469

@@ -74,3 +79,71 @@ def test_cli_download(id_list, file_format, ids, table):
7479
assert result.exit_code == 0
7580
assert len(os.listdir(test_download_dir)) > 0
7681
cleanup(test_download_dir)
82+
83+
84+
# utility function testing
85+
86+
@pytest.mark.parametrize("files_raises", [{"file": "test.txt", "raises": False},
87+
{"file": "../test.txt", "raises": True},
88+
{"file": None, "raises": True}])
89+
def test_tar_unpack(files_raises):
90+
"""
91+
Test the unpacking of a tarfile
92+
"""
93+
filename = files_raises["file"]
94+
raises = files_raises["raises"]
95+
if filename is None: # To hit FileNotFoundError branch
96+
filename = 'test.txt'
97+
real_exists = os.path.exists
98+
def mock_exists(path):
99+
if path.endswith(filename):
100+
return False
101+
return real_exists(path)
102+
exists_patcher = patch('os.path.exists', mock_exists)
103+
exists_patcher.start()
104+
105+
# Create a tarfile with known content
106+
with tempfile.NamedTemporaryFile(delete=False, suffix='.tar.gz') as tmp:
107+
tar_path = tmp.name
108+
with tarfile.open(tar_path, "w:gz") as tar:
109+
info = tarfile.TarInfo(name=filename)
110+
content = b"Hello, World!"
111+
info.size = len(content)
112+
temp_content_file = tempfile.NamedTemporaryFile(delete=False)
113+
try:
114+
temp_content_file.write(content)
115+
temp_content_file.close()
116+
tar.add(temp_content_file.name, arcname=filename)
117+
finally:
118+
os.remove(temp_content_file.name)
119+
120+
test_download_dir = './.pytest_downloads/'
121+
mkdir(test_download_dir)
122+
assert len(os.listdir(test_download_dir)) == 0
123+
124+
# Mock the requests part to return our tarfile
125+
with patch('hepdata_cli.api.is_downloadable', return_value=True), \
126+
patch('hepdata_cli.api.resilient_requests') as mock_requests, \
127+
patch('hepdata_cli.api.getFilename_fromCd', return_value='test.tar.gz'):
128+
129+
mock_response = mock_requests.return_value
130+
mock_response.content = open(tar_path, 'rb').read()
131+
mock_response.headers = {'content-disposition': 'filename=test.tar.gz'}
132+
133+
# Test the download_url function
134+
try:
135+
if raises:
136+
with pytest.raises(Exception):
137+
files = download_url('http://example.com/test.tar.gz', test_download_dir)
138+
else:
139+
files = download_url('http://example.com/test.tar.gz', test_download_dir)
140+
assert len(files) == 1
141+
for f in files:
142+
assert os.path.exists(f)
143+
with open(f, 'rb') as fr:
144+
assert fr.read() == b"Hello, World!"
145+
finally:
146+
exists_patcher.stop() if filename is None else None
147+
if os.path.exists(tar_path):
148+
os.remove(tar_path)
149+
cleanup(test_download_dir)

0 commit comments

Comments
 (0)