5 changes: 3 additions & 2 deletions README.md
@@ -155,10 +155,11 @@ or equivalently

```python
id_list = client.find('reactions:"P P--> LQ LQ"', ids='inspire')
client.download(id_list, ids='inspire', file_format='csv')
downloads = client.download(id_list, ids='inspire', file_format='csv')
print(downloads) # {'1222326': ['./hepdata-downloads/HEPData-ins1222326-v1-csv/Table1.csv', ...], ...}
```

downloads four .tar.gz archives containing csv files and unpacks them in the default ```./hepdata-downloads``` directory.
downloads four .tar.gz archives containing CSV files and unpacks them in the default ```./hepdata-downloads``` directory. When using the API, a dictionary mapping each id to the list of downloaded files is returned.
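For instance, the returned mapping can be consumed directly (a minimal sketch, reusing `downloads` from the snippet above; the printed id and path are illustrative):

```python
for record_id, file_paths in downloads.items():
    for path in file_paths:
        print(record_id, path)  # e.g. 1222326 ./hepdata-downloads/HEPData-ins1222326-v1-csv/Table1.csv
```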

### Example 5 - find table names in records:

64 changes: 52 additions & 12 deletions hepdata_cli/api.py
@@ -87,13 +87,19 @@ def download(self, id_list, file_format=None, ids=None, table_name='', download_
:param ids: accepts one of ('inspire', 'hepdata'). It specifies what type of ids have been passed.
:param table_name: restricts download to specific tables.
:param download_dir: defaults to ./hepdata-downloads. Specifies where to download the files.

:return: dictionary mapping each id to the list of downloaded files.
:rtype: dict[str, list[str]]
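
Example (a minimal sketch; the id and paths are illustrative, cf. the README):
    >>> client = Client()
    >>> client.download(['1222326'], file_format='csv', ids='inspire')
    {'1222326': ['./hepdata-downloads/HEPData-ins1222326-v1-csv/Table1.csv', ...]}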
"""

urls = self._build_urls(id_list, file_format, ids, table_name)
for url in urls:
url_map = self._build_urls(id_list, file_format, ids, table_name)
file_map = {}
for record_id, url in url_map.items():
if self.verbose is True:
print("Downloading: " + url)
download_url(url, download_dir)
files_downloaded = download_url(url, download_dir)
file_map[record_id] = files_downloaded
return file_map

def fetch_names(self, id_list, ids=None):
"""
@@ -102,9 +108,9 @@ def fetch_names(self, id_list, ids=None):
:param id_list: list of id of records of which to return table names.
:param ids: accepts one of ('inspire', 'hepdata'). It specifies what type of ids have been passed.
"""
urls = self._build_urls(id_list, 'json', ids, '')
url_map = self._build_urls(id_list, 'json', ids, '')
table_names = []
for url in urls:
for url in url_map.values():
response = resilient_requests('get', url)
json_dict = response.json()
table_names += [[data_table['name'] for data_table in json_dict['data_tables']]]
@@ -136,7 +142,16 @@ def upload(self, path_to_file, email, recid=None, invitation_cookie=None, sandbo
print('Uploaded ' + path_to_file + ' to ' + SITE_URL + '/record/' + str(recid))

def _build_urls(self, id_list, file_format, ids, table_name):
"""Builds urls for download and fetch_names, given the specified parameters."""
"""
Builds urls for download and fetch_names, given the specified parameters.

:param id_list: list of ids to download.
:param file_format: accepts one of ('csv', 'root', 'yaml', 'yoda', 'yoda1', 'yoda.h5', 'json').
:param ids: accepts one of ('inspire', 'hepdata').
:param table_name: restricts download to specific tables.

:return: dictionary mapping id to url.
"""
if type(id_list) not in (tuple, list):
id_list = id_list.split()
assert len(id_list) > 0, 'Ids are required.'
@@ -146,9 +161,12 @@ def _build_urls(self, id_list, file_format, ids, table_name):
params = {'format': file_format}
else:
params = {'format': file_format, 'table': table_name}
urls = [resilient_requests('get', SITE_URL + '/record/' + ('ins' if ids == 'inspire' else '') + id_entry, params=params).url.replace('%2525', '%25') for id_entry in id_list]
url_mapping = {}
for id_entry in id_list:
url = resilient_requests('get', SITE_URL + '/record/' + ('ins' if ids == 'inspire' else '') + id_entry, params=params).url.replace('%2525', '%25')
url_mapping[id_entry] = url
# TODO: Investigate root cause of double URL encoding (https://github.com/HEPData/hepdata-cli/issues/8).
return urls
return url_mapping

def _query(self, query, page, size):
"""Builds the search query passed to hepdata.net."""
@@ -170,6 +188,7 @@ def mkdir(directory):

def download_url(url, download_dir):
"""Download file and if necessary extract it."""
files_downloaded = []
assert is_downloadable(url), "Given url is not downloadable: {}".format(url)
response = resilient_requests('get', url, allow_redirects=True)
if url[-4:] == 'json':
@@ -182,10 +201,31 @@ def download_url(url, download_dir):
mkdir(os.path.dirname(filepath))
open(filepath, 'wb').write(response.content)
if filepath.endswith("tar.gz") or filepath.endswith("tar"):
tar = tarfile.open(filepath, "r:gz" if filepath.endswith("tar.gz") else "r:")
tar.extractall(path=os.path.dirname(filepath))
tar.close()
os.remove(filepath)
tar = None
try:
    tar = tarfile.open(filepath, "r:gz" if filepath.endswith("tar.gz") else "r:")
    extract_dir = os.path.abspath(os.path.dirname(filepath))
    members = tar.getmembers()
    # Validate every member *before* extraction, so a malicious archive
    # cannot write outside the download directory (path traversal).
    for member in members:
        abs_extracted_path = os.path.abspath(os.path.join(extract_dir, member.name))
        if not abs_extracted_path.startswith(extract_dir + os.sep):
            raise ValueError(f"Attempted path traversal for file {member.name}")
    tar.extractall(path=extract_dir)
    # Record the extracted files, confirming each one actually landed on disk.
    for member in members:
        if member.isfile():
            abs_extracted_path = os.path.abspath(os.path.join(extract_dir, member.name))
            if not os.path.exists(abs_extracted_path):
                raise FileNotFoundError(f"Extracted file {member.name} not found")
            files_downloaded.append(abs_extracted_path)
except Exception as e:
    raise Exception(f"Failed to extract {filepath}: {str(e)}")
finally:
    if tar:
        tar.close()
    if os.path.exists(filepath):
        os.remove(filepath)
else:
files_downloaded.append(filepath)
return files_downloaded


def getFilename_fromCd(cd):
2 changes: 1 addition & 1 deletion hepdata_cli/version.py
@@ -1 +1 @@
__version__ = "0.2.3"
__version__ = "0.3.0"
77 changes: 75 additions & 2 deletions tests/test_download.py
@@ -3,10 +3,13 @@
import pytest
import os
import shutil
import tarfile
import tempfile
from unittest.mock import patch

from click.testing import CliRunner

from hepdata_cli.api import Client, mkdir
from hepdata_cli.api import Client, download_url, mkdir
from hepdata_cli.cli import cli


@@ -57,8 +60,10 @@ def test_api_download(id_list, file_format, ids, table):
mkdir(test_download_dir)
assert len(os.listdir(test_download_dir)) == 0
client = Client(verbose=True)
client.download(id_list, file_format, ids, table, test_download_dir)
path_map = client.download(id_list, file_format, ids, table, test_download_dir)
file_paths = [fp for fps in path_map.values() for fp in fps]
assert len(os.listdir(test_download_dir)) > 0
assert all(os.path.exists(fp) for fp in file_paths)
cleanup(test_download_dir)


@@ -74,3 +79,71 @@ def test_cli_download(id_list, file_format, ids, table):
assert result.exit_code == 0
assert len(os.listdir(test_download_dir)) > 0
cleanup(test_download_dir)


# utility function testing

@pytest.mark.parametrize("files_raises", [{"file": "test.txt", "raises": False},
{"file": "../test.txt", "raises": True},
{"file": None, "raises": True}])
def test_tar_unpack(files_raises):
"""
Test the unpacking of a tarfile
"""
filename = files_raises["file"]
raises = files_raises["raises"]
if filename is None: # To hit FileNotFoundError branch
filename = 'test.txt'
real_exists = os.path.exists
def mock_exists(path):
if path.endswith(filename):
return False
return real_exists(path)
exists_patcher = patch('os.path.exists', mock_exists)
exists_patcher.start()

# Create a tarfile with known content
with tempfile.NamedTemporaryFile(delete=False, suffix='.tar.gz') as tmp:
tar_path = tmp.name
with tarfile.open(tar_path, "w:gz") as tar:
info = tarfile.TarInfo(name=filename)
content = b"Hello, World!"
info.size = len(content)
temp_content_file = tempfile.NamedTemporaryFile(delete=False)
try:
temp_content_file.write(content)
temp_content_file.close()
tar.add(temp_content_file.name, arcname=filename)
finally:
os.remove(temp_content_file.name)

test_download_dir = './.pytest_downloads/'
mkdir(test_download_dir)
assert len(os.listdir(test_download_dir)) == 0

# Mock the requests part to return our tarfile
with patch('hepdata_cli.api.is_downloadable', return_value=True), \
patch('hepdata_cli.api.resilient_requests') as mock_requests, \
patch('hepdata_cli.api.getFilename_fromCd', return_value='test.tar.gz'):

mock_response = mock_requests.return_value
with open(tar_path, 'rb') as tar_file:
    mock_response.content = tar_file.read()
mock_response.headers = {'content-disposition': 'filename=test.tar.gz'}

# Test the download_url function
try:
if raises:
with pytest.raises(Exception):
files = download_url('http://example.com/test.tar.gz', test_download_dir)
else:
files = download_url('http://example.com/test.tar.gz', test_download_dir)
assert len(files) == 1
for f in files:
assert os.path.exists(f)
with open(f, 'rb') as fr:
assert fr.read() == b"Hello, World!"
finally:
if files_raises["file"] is None:  # filename was reassigned above, so test the original value
    exists_patcher.stop()
if os.path.exists(tar_path):
os.remove(tar_path)
cleanup(test_download_dir)