diff --git a/README.md b/README.md index 340ef59..e504c33 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,8 @@ client.fetch_names(id_list, ids) client.upload(path_to_file, email, recid, invitation_cookie, sandbox, password) ``` +When `ids` is given, `client.find()` accepts the keyword argument `format` to select the type in which the ids are returned: `str` (space-separated), `list`, `set`, or `tuple`. +The default is `str`. ## Examples @@ -188,8 +190,7 @@ Then, ```python import hepdata_cli hepdata_client = hepdata_cli.Client() -id_list = hepdata_client.find('reactions:"P P --> LQ LQ X"', ids='arxiv') -id_list = id_list.split() +id_list = hepdata_client.find('reactions:"P P --> LQ LQ X"', ids='arxiv', format=list) print(id_list) # ['1605.06035', '2101.11582', ...] import arxiv diff --git a/hepdata_cli/api.py b/hepdata_cli/api.py index 60e7ec6..91252ad 100644 --- a/hepdata_cli/api.py +++ b/hepdata_cli/api.py @@ -32,15 +32,18 @@ def __init__(self, verbose=False): # check service availability resilient_requests('get', SITE_URL + '/ping') - def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_per_page=MATCHES_PER_PAGE): + def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_per_page=MATCHES_PER_PAGE, format=str): """ Search function for the hepdata database. Calls hepdata.net search function. :param query: string passed to hepdata.net search function. See advanced search tips at hepdata.net. :param keyword: filters return dictionary for given keyword. Exact match is first attempted, otherwise partial match is accepted. :param ids: accepts one of ("arxiv", "inspire", "hepdata"). + :param max_matches: maximum number of matches to return. Default is 10,000. + :param matches_per_page: number of matches per page. Default is 10. + :param format: specifies the return format if 'ids' is specified. Allowed formats are: str, list, set, tuple. Default is str. - :return: returns a list of (filtered if 'keyword' is specified) dictionaries for the search matches. 
If 'ids' is specified it instead returns a list of ids as a string. + :return: returns a list of (filtered if 'keyword' is specified) dictionaries for the search matches. If 'ids' is specified it instead returns a list of ids in the format 'format'. """ find_results = [] for counter in range(int(max_matches / matches_per_page)): @@ -53,7 +56,7 @@ def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_p # return full list of dictionary find_results += data['results'] else: - assert ids in [None, "arxiv", "inspire", "hepdata", "id"], "allowd ids are: arxiv, inspire and hepdata" + assert ids in [None, "arxiv", "inspire", "hepdata", "id"], "allowed ids are: arxiv, inspire and hepdata" if ids is not None: if ids == "hepdata": ids = "id" @@ -76,7 +79,14 @@ def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_p if ids is None: return find_results else: - return ' '.join(find_results) + if format==str: + return ' '.join(find_results) + elif format==list: + return find_results + elif format in (set, tuple): + return format(find_results) + else: + raise TypeError(f"Cannot return results in specfied format: {format}. Allowed formats are: {str}, {list}, {set}, {tuple}.") def download(self, id_list, file_format=None, ids=None, table_name='', download_dir='./hepdata-downloads'): """ @@ -145,14 +155,14 @@ def _build_urls(self, id_list, file_format, ids, table_name): """ Builds urls for download and fetch_names, given the specified parameters. - :param id_list: list of ids to download. + :param id_list: list of ids to download. Format is tuple, list, set or space-separated string. :param file_format: accepts one of ('csv', 'root', 'yaml', 'yoda', 'yoda1', 'yoda.h5', 'json'). :param ids: accepts one of ('inspire', 'hepdata'). :param table_name: restricts download to specific tables. :return: dictionary mapping id to url. 
""" - if type(id_list) not in (tuple, list): + if isinstance(id_list, str): id_list = id_list.split() assert len(id_list) > 0, 'Ids are required.' assert file_format in ALLOWED_FORMATS, f"allowed formats are: {ALLOWED_FORMATS}" diff --git a/hepdata_cli/version.py b/hepdata_cli/version.py index 493f741..260c070 100644 --- a/hepdata_cli/version.py +++ b/hepdata_cli/version.py @@ -1 +1 @@ -__version__ = "0.3.0" +__version__ = "0.3.1" diff --git a/tests/test_download.py b/tests/test_download.py index 50ceb3f..4e834f8 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -38,13 +38,23 @@ def cleanup(directory): test_api_download_arguments = [ (["73322"], "json", "hepdata", ''), - (["1222326", "1694381", "1462258", "1309874"], "csv", "inspire", ''), + ("1222326 1694381 1462258 1309874", "csv", "inspire", ''), # str + (["1222326", "1694381", "1462258", "1309874"], "csv", "inspire", ''), # list + ({"1222326", "1694381", "1462258", "1309874"}, "csv", "inspire", ''), # set + (("1222326", "1694381", "1462258", "1309874"), "csv", "inspire", ''), # tuple (["61434"], "yaml", "hepdata", "Table1"), (["1762350"], "yoda", "inspire", "Number density and Sum p_T pT>0.15 GeV/c"), (["2862529"], "yoda.h5", "inspire", "95% CL upper limit on XSEC times BF"), (["2862529"], "yoda.h5", "inspire", '') ] +test_api_find_download_arguments = [ + ("json", "hepdata", str), + ("csv", "inspire", list), + ("json", "inspire", set), + ("csv", "hepdata", tuple), +] + test_cli_download_arguments = [ (["2862529"], "json", "inspire", ''), (["1222326", "1694381", "1462258", "1309874"], "root", "inspire", ''), @@ -54,18 +64,29 @@ def cleanup(directory): # api testing +def download_and_test(client, id_list, file_format, ids, table, test_download_dir): + path_map = client.download(id_list, file_format, ids, table, test_download_dir) + file_paths = [fp for fps in path_map.values() for fp in fps] + assert len(os.listdir(test_download_dir)) > 0 + assert all(os.path.exists(fp) for fp in file_paths) + 
cleanup(test_download_dir) + @pytest.mark.parametrize("id_list, file_format, ids, table", test_api_download_arguments) def test_api_download(id_list, file_format, ids, table): test_download_dir = './.pytest_downloads/' mkdir(test_download_dir) assert len(os.listdir(test_download_dir)) == 0 client = Client(verbose=True) - path_map = client.download(id_list, file_format, ids, table, test_download_dir) - file_paths = [fp for fps in path_map.values() for fp in fps] - assert len(os.listdir(test_download_dir)) > 0 - assert all(os.path.exists(fp) for fp in file_paths) - cleanup(test_download_dir) + download_and_test(client, id_list, file_format, ids, table, test_download_dir) +@pytest.mark.parametrize("file_format, ids, format", test_api_find_download_arguments) +def test_api_find_download(file_format, ids, format): + test_download_dir = './.pytest_downloads/' + mkdir(test_download_dir) + assert len(os.listdir(test_download_dir)) == 0 + client = Client(verbose=True) + id_list = client.find('reactions:"P P --> LQ LQ"', ids=ids, format=format) + download_and_test(client, id_list, file_format, ids, '', test_download_dir) # cli testing diff --git a/tests/test_search.py b/tests/test_search.py index ccf5633..82f89d7 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -4,16 +4,19 @@ from click.testing import CliRunner -from hepdata_cli.api import Client +from hepdata_cli.api import Client, MAX_MATCHES, MATCHES_PER_PAGE from hepdata_cli.cli import cli # arguments for testing test_api_find_arguments = [ - ('reactions:"P P --> LQ LQ X"', None, None), - ('reactions:"P P --> LQ LQ"', 'year', None), - ('phrases:"(diffractive AND elastic)"', None, 'arxiv'), + ('reactions:"P P --> LQ LQ X"', None, None, None), + ('reactions:"P P --> LQ LQ"', 'year', None, None), + ('phrases:"(diffractive AND elastic)"', None, 'arxiv', str), + ('phrases:"(diffractive AND elastic)"', None, 'hepdata', list), + ('reactions:"P P --> LQ LQ X"', None, 'arxiv', set), + ('reactions:"P P --> LQ LQ 
X"', None, 'inspire', tuple), ] test_cli_find_arguments = [ @@ -24,17 +27,16 @@ # api test -@pytest.mark.parametrize("query, keyword, ids", test_api_find_arguments) -def test_api_find(query, keyword, ids): +@pytest.mark.parametrize("query, keyword, ids, format", test_api_find_arguments) +def test_api_find(query, keyword, ids, format): client = Client(verbose=True) - search_result = client.find(query, keyword, ids) + search_result = client.find(query, keyword, ids, format=format) if ids is None: assert type(search_result) is list if len(search_result) > 0: assert all([type(entry) is dict for entry in search_result]) else: - assert type(search_result) is str - + assert type(search_result) is format # cli testing