diff --git a/README.md b/README.md
index 6317015..0b9bbd0 100644
--- a/README.md
+++ b/README.md
@@ -155,10 +155,11 @@ or equivalently

 ```python
 id_list = client.find('reactions:"P P--> LQ LQ"', ids='inspire')
-client.download(id_list, ids='inspire', file_format='csv')
+downloads = client.download(id_list, ids='inspire', file_format='csv')
+print(downloads)  # {'1222326': ['./hepdata-downloads/HEPData-ins1222326-v1-csv/Table1.csv', ...], ...}
 ```

-downloads four .tar.gz archives containing csv files and unpacks them in the default ```./hepdata-downloads``` directory.
+downloads four .tar.gz archives containing csv files and unpacks them in the default ```./hepdata-downloads``` directory. When called through the API, a dictionary mapping each id to its downloaded files is also returned.

 ### Example 5 - find table names in records:
diff --git a/hepdata_cli/api.py b/hepdata_cli/api.py
index bf263c2..60e7ec6 100644
--- a/hepdata_cli/api.py
+++ b/hepdata_cli/api.py
@@ -87,13 +87,19 @@ def download(self, id_list, file_format=None, ids=None, table_name='', download_
         :param ids: accepts one of ('inspire', 'hepdata'). It specifies what type of ids have been passed.
         :param table_name: restricts download to specific tables.
         :param download_dir: defaults to ./hepdata-downloads. Specifies where to download the files.
+
+        :return: dictionary mapping each id to the list of files downloaded for it.
+        :rtype: dict[str, list[str]]
         """
-        urls = self._build_urls(id_list, file_format, ids, table_name)
-        for url in urls:
+        url_map = self._build_urls(id_list, file_format, ids, table_name)
+        file_map = {}
+        for record_id, url in url_map.items():
             if self.verbose is True:
                 print("Downloading: " + url)
-            download_url(url, download_dir)
+            files_downloaded = download_url(url, download_dir)
+            file_map[record_id] = files_downloaded
+        return file_map
@@ -102,9 +108,9 @@ def fetch_names(self, id_list, ids=None):
         """
         :param id_list: list of id of records of which to return table names.
         :param ids: accepts one of ('inspire', 'hepdata'). It specifies what type of ids have been passed.
         """
-        urls = self._build_urls(id_list, 'json', ids, '')
+        url_map = self._build_urls(id_list, 'json', ids, '')
         table_names = []
-        for url in urls:
+        for url in url_map.values():
             response = resilient_requests('get', url)
             json_dict = response.json()
             table_names += [[data_table['name'] for data_table in json_dict['data_tables']]]
@@ -136,7 +142,16 @@ def upload(self, path_to_file, email, recid=None, invitation_cookie=None, sandbo
             print('Uploaded ' + path_to_file + ' to ' + SITE_URL + '/record/' + str(recid))

     def _build_urls(self, id_list, file_format, ids, table_name):
-        """Builds urls for download and fetch_names, given the specified parameters."""
+        """
+        Builds urls for download and fetch_names, given the specified parameters.
+
+        :param id_list: list of ids to download.
+        :param file_format: accepts one of ('csv', 'root', 'yaml', 'yoda', 'yoda1', 'yoda.h5', 'json').
+        :param ids: accepts one of ('inspire', 'hepdata').
+        :param table_name: restricts download to specific tables.
+
+        :return: dictionary mapping each id to its url.
+        """
         if type(id_list) not in (tuple, list):
             id_list = id_list.split()
         assert len(id_list) > 0, 'Ids are required.'
@@ -146,9 +161,12 @@ def _build_urls(self, id_list, file_format, ids, table_name):
             params = {'format': file_format}
         else:
             params = {'format': file_format, 'table': table_name}
-        urls = [resilient_requests('get', SITE_URL + '/record/' + ('ins' if ids == 'inspire' else '') + id_entry, params=params).url.replace('%2525', '%25') for id_entry in id_list]
+        url_mapping = {}
+        for id_entry in id_list:
+            url = resilient_requests('get', SITE_URL + '/record/' + ('ins' if ids == 'inspire' else '') + id_entry, params=params).url.replace('%2525', '%25')
+            url_mapping[id_entry] = url
         # TODO: Investigate root cause of double URL encoding (https://github.com/HEPData/hepdata-cli/issues/8).
-        return urls
+        return url_mapping

     def _query(self, query, page, size):
         """Builds the search query passed to hepdata.net."""
@@ -170,6 +188,7 @@ def mkdir(directory):

 def download_url(url, download_dir):
     """Download file and if necessary extract it."""
+    files_downloaded = []
     assert is_downloadable(url), "Given url is not downloadable: {}".format(url)
     response = resilient_requests('get', url, allow_redirects=True)
     if url[-4:] == 'json':
@@ -182,10 +201,32 @@
         mkdir(os.path.dirname(filepath))
         open(filepath, 'wb').write(response.content)
     if filepath.endswith("tar.gz") or filepath.endswith("tar"):
-        tar = tarfile.open(filepath, "r:gz" if filepath.endswith("tar.gz") else "r:")
-        tar.extractall(path=os.path.dirname(filepath))
-        tar.close()
-        os.remove(filepath)
+        tar = None
+        try:
+            tar = tarfile.open(filepath, "r:gz" if filepath.endswith("tar.gz") else "r:")
+            extract_dir = os.path.abspath(os.path.dirname(filepath))
+            # Validate member paths before extracting anything, to prevent path traversal.
+            for member in tar.getmembers():
+                abs_extracted_path = os.path.abspath(os.path.join(extract_dir, member.name))
+                if not abs_extracted_path.startswith(extract_dir + os.sep) and abs_extracted_path != extract_dir:
+                    raise ValueError(f"Attempted path traversal for file {member.name}")
+            tar.extractall(path=extract_dir)
+            for member in tar.getmembers():
+                if member.isfile():
+                    abs_extracted_path = os.path.abspath(os.path.join(extract_dir, member.name))
+                    if not os.path.exists(abs_extracted_path):
+                        raise FileNotFoundError(f"Extracted file {member.name} not found")
+                    files_downloaded.append(abs_extracted_path)
+        except Exception as e:
+            raise Exception(f"Failed to extract {filepath}: {str(e)}")
+        finally:
+            if tar:
+                tar.close()
+            if os.path.exists(filepath):
+                os.remove(filepath)
+    else:
+        files_downloaded.append(filepath)
+    return files_downloaded


 def getFilename_fromCd(cd):
diff --git a/hepdata_cli/version.py b/hepdata_cli/version.py
index d31c31e..493f741 100644
--- a/hepdata_cli/version.py
+++ b/hepdata_cli/version.py
@@ -1 +1 @@
-__version__ = "0.2.3"
+__version__ = "0.3.0"
diff --git a/tests/test_download.py b/tests/test_download.py
index 28a727b..50ceb3f 100644
--- a/tests/test_download.py
+++ b/tests/test_download.py
@@ -3,10 +3,14 @@
 import pytest
 import os
 import shutil
+import io
+import tarfile
+import tempfile
+from unittest.mock import patch

 from click.testing import CliRunner

-from hepdata_cli.api import Client, mkdir
+from hepdata_cli.api import Client, download_url, mkdir
 from hepdata_cli.cli import cli

@@ -57,8 +61,10 @@ def test_api_download(id_list, file_format, ids, table):
     mkdir(test_download_dir)
     assert len(os.listdir(test_download_dir)) == 0
     client = Client(verbose=True)
-    client.download(id_list, file_format, ids, table, test_download_dir)
+    path_map = client.download(id_list, file_format, ids, table, test_download_dir)
+    file_paths = [fp for fps in path_map.values() for fp in fps]
     assert len(os.listdir(test_download_dir)) > 0
+    assert all(os.path.exists(fp) for fp in file_paths)
     cleanup(test_download_dir)

@@ -74,3 +80,71 @@
     assert result.exit_code == 0
     assert len(os.listdir(test_download_dir)) > 0
     cleanup(test_download_dir)
+
+
+# utility function testing
+
+@pytest.mark.parametrize("files_raises", [{"file": "test.txt", "raises": False},
+                                          {"file": "../test.txt", "raises": True},
+                                          {"file": None, "raises": True}])
+def test_tar_unpack(files_raises):
+    """
+    Test the unpacking of a tarfile.
+    """
+    filename = files_raises["file"]
+    raises = files_raises["raises"]
+    exists_patcher = None
+    if filename is None:  # simulate a missing file to hit the FileNotFoundError branch
+        filename = 'test.txt'
+        real_exists = os.path.exists
+
+        def mock_exists(path):
+            # Report the extracted file as missing; leave all other paths alone.
+            if path.endswith(filename):
+                return False
+            return real_exists(path)
+
+        exists_patcher = patch('os.path.exists', mock_exists)
+        exists_patcher.start()
+
+    # Create a tarfile with known content
+    with tempfile.NamedTemporaryFile(delete=False, suffix='.tar.gz') as tmp:
+        tar_path = tmp.name
+    content = b"Hello, World!"
+    with tarfile.open(tar_path, "w:gz") as tar:
+        info = tarfile.TarInfo(name=filename)
+        info.size = len(content)
+        tar.addfile(info, io.BytesIO(content))
+
+    test_download_dir = './.pytest_downloads/'
+    mkdir(test_download_dir)
+    assert len(os.listdir(test_download_dir)) == 0
+
+    # Mock the requests part to return our tarfile
+    with patch('hepdata_cli.api.is_downloadable', return_value=True), \
+            patch('hepdata_cli.api.resilient_requests') as mock_requests, \
+            patch('hepdata_cli.api.getFilename_fromCd', return_value='test.tar.gz'):
+
+        mock_response = mock_requests.return_value
+        with open(tar_path, 'rb') as tar_file:
+            mock_response.content = tar_file.read()
+        mock_response.headers = {'content-disposition': 'filename=test.tar.gz'}
+
+        # Test the download_url function
+        try:
+            if raises:
+                with pytest.raises(Exception):
+                    download_url('http://example.com/test.tar.gz', test_download_dir)
+            else:
+                files = download_url('http://example.com/test.tar.gz', test_download_dir)
+                assert len(files) == 1
+                for f in files:
+                    assert os.path.exists(f)
+                    with open(f, 'rb') as fr:
+                        assert fr.read() == b"Hello, World!"
+        finally:
+            if exists_patcher is not None:
+                exists_patcher.stop()
+            if os.path.exists(tar_path):
+                os.remove(tar_path)
+            cleanup(test_download_dir)
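
Note for reviewers: a minimal sketch of how the new `download` return value could be consumed, assuming the `Client` API exactly as shown in this diff (the record id `1222326` is borrowed from the README example above; the actual paths depend on your `download_dir`):

```python
from hepdata_cli.api import Client

client = Client(verbose=True)

# After this change, download returns {record_id: [downloaded file paths]}
# instead of None.
file_map = client.download(['1222326'], file_format='csv', ids='inspire')

for record_id, paths in file_map.items():
    print(record_id + ' -> ' + str(len(paths)) + ' files')
    for path in paths:
        print('    ' + path)
```

Each value lists the files that `download_url` either saved directly or unpacked from a downloaded archive, so callers no longer need to walk `download_dir` themselves to locate their tables.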