Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ client.fetch_names(id_list, ids)
client.upload(path_to_file, email, recid, invitation_cookie, sandbox, password)

```
`client.find()` takes the keyword argument `format` to specify which format from `str`, `list`, `set`, or `tuple` shall be returned.
Default is `str`.

## Examples

Expand Down Expand Up @@ -188,8 +190,7 @@ Then,
```python
import hepdata_cli
hepdata_client = hepdata_cli.Client()
id_list = hepdata_client.find('reactions:"P P --> LQ LQ X"', ids='arxiv')
id_list = id_list.split()
id_list = hepdata_client.find('reactions:"P P --> LQ LQ X"', ids='arxiv', format=list)
print(id_list) # ['1605.06035', '2101.11582', ...]

import arxiv
Expand Down
22 changes: 16 additions & 6 deletions hepdata_cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,18 @@ def __init__(self, verbose=False):
# check service availability
resilient_requests('get', SITE_URL + '/ping')

def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_per_page=MATCHES_PER_PAGE):
def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_per_page=MATCHES_PER_PAGE, format=str):
"""
Search function for the hepdata database. Calls hepdata.net search function.

:param query: string passed to hepdata.net search function. See advanced search tips at hepdata.net.
:param keyword: filters return dictionary for given keyword. Exact match is first attempted, otherwise partial match is accepted.
:param ids: accepts one of ("arxiv", "inspire", "hepdata").
:param max_matches: maximum number of matches to return. Default is 10,000.
:param matches_per_page: number of matches per page. Default is 10.
:param format: specifies the return format if 'ids' is specified. Allowed formats are: str, list, set, tuple. Default is str.

:return: returns a list of (filtered if 'keyword' is specified) dictionaries for the search matches. If 'ids' is specified it instead returns a list of ids as a string.
:return: returns a list of (filtered if 'keyword' is specified) dictionaries for the search matches. If 'ids' is specified it instead returns a list of ids in the format 'format'.
"""
find_results = []
for counter in range(int(max_matches / matches_per_page)):
Expand All @@ -53,7 +56,7 @@ def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_p
# return full list of dictionary
find_results += data['results']
else:
assert ids in [None, "arxiv", "inspire", "hepdata", "id"], "allowd ids are: arxiv, inspire and hepdata"
assert ids in [None, "arxiv", "inspire", "hepdata", "id"], "allowed ids are: arxiv, inspire and hepdata"
if ids is not None:
if ids == "hepdata":
ids = "id"
Expand All @@ -76,7 +79,14 @@ def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_p
if ids is None:
return find_results
else:
return ' '.join(find_results)
if format==str:
Copy link

Copilot AI Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing spaces around the equality operator. Should be if format == str: to follow PEP 8 style guidelines.

Suggested change
if format==str:
if format == str:

Copilot uses AI. Check for mistakes.
return ' '.join(find_results)
elif format==list:
return find_results
elif format in (set, tuple):
return format(find_results)
else:
raise TypeError(f"Cannot return results in specfied format: {format}. Allowed formats are: {str}, {list}, {set}, {tuple}.")
Copy link

Copilot AI Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected spelling of 'specfied' to 'specified'.

Suggested change
raise TypeError(f"Cannot return results in specfied format: {format}. Allowed formats are: {str}, {list}, {set}, {tuple}.")
raise TypeError(f"Cannot return results in specified format: {format}. Allowed formats are: {str}, {list}, {set}, {tuple}.")

Copilot uses AI. Check for mistakes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is not covered by the tests, resulting in a -0.4% decrease in test coverage for the PR. Please add a test that raises the TypeError exception.

Comment on lines +88 to +89
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in error message: 'specfied' should be 'specified'. This is a user-facing error message that will appear unprofessional with the typo.

Did we get this right? 👍 / 👎 to inform future reviews.


def download(self, id_list, file_format=None, ids=None, table_name='', download_dir='./hepdata-downloads'):
"""
Expand Down Expand Up @@ -145,14 +155,14 @@ def _build_urls(self, id_list, file_format, ids, table_name):
"""
Builds urls for download and fetch_names, given the specified parameters.

:param id_list: list of ids to download.
:param id_list: list of ids to download. Format is tuple, list, set or space-separated string.
:param file_format: accepts one of ('csv', 'root', 'yaml', 'yoda', 'yoda1', 'yoda.h5', 'json').
:param ids: accepts one of ('inspire', 'hepdata').
:param table_name: restricts download to specific tables.

:return: dictionary mapping id to url.
"""
if type(id_list) not in (tuple, list):
if isinstance(id_list, str):
id_list = id_list.split()
assert len(id_list) > 0, 'Ids are required.'
assert file_format in ALLOWED_FORMATS, f"allowed formats are: {ALLOWED_FORMATS}"
Expand Down
2 changes: 1 addition & 1 deletion hepdata_cli/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.0"
__version__ = "0.3.1"
33 changes: 27 additions & 6 deletions tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,23 @@ def cleanup(directory):

test_api_download_arguments = [
(["73322"], "json", "hepdata", ''),
(["1222326", "1694381", "1462258", "1309874"], "csv", "inspire", ''),
("1222326 1694381 1462258 1309874", "csv", "inspire", ''), # str
(["1222326", "1694381", "1462258", "1309874"], "csv", "inspire", ''), # list
({"1222326", "1694381", "1462258", "1309874"}, "csv", "inspire", ''), # set
(("1222326", "1694381", "1462258", "1309874"), "csv", "inspire", ''), # tuple
(["61434"], "yaml", "hepdata", "Table1"),
(["1762350"], "yoda", "inspire", "Number density and Sum p_T pT>0.15 GeV/c"),
(["2862529"], "yoda.h5", "inspire", "95% CL upper limit on XSEC times BF"),
(["2862529"], "yoda.h5", "inspire", '')
]

test_api_find_download_arguments = [
("json", "hepdata", str),
("csv", "inspire", list),
("json", "inspire", set),
("csv", "hepdata", tuple),
]

test_cli_download_arguments = [
(["2862529"], "json", "inspire", ''),
(["1222326", "1694381", "1462258", "1309874"], "root", "inspire", ''),
Expand All @@ -54,18 +64,29 @@ def cleanup(directory):

# api testing

def download_and_test(client, id_list, file_format, ids, table, test_download_dir):
path_map = client.download(id_list, file_format, ids, table, test_download_dir)
file_paths = [fp for fps in path_map.values() for fp in fps]
assert len(os.listdir(test_download_dir)) > 0
assert all(os.path.exists(fp) for fp in file_paths)
cleanup(test_download_dir)

@pytest.mark.parametrize("id_list, file_format, ids, table", test_api_download_arguments)
def test_api_download(id_list, file_format, ids, table):
test_download_dir = './.pytest_downloads/'
mkdir(test_download_dir)
assert len(os.listdir(test_download_dir)) == 0
client = Client(verbose=True)
path_map = client.download(id_list, file_format, ids, table, test_download_dir)
file_paths = [fp for fps in path_map.values() for fp in fps]
assert len(os.listdir(test_download_dir)) > 0
assert all(os.path.exists(fp) for fp in file_paths)
cleanup(test_download_dir)
download_and_test(client, id_list, file_format, ids, table, test_download_dir)

@pytest.mark.parametrize("file_format, ids, format", test_api_find_download_arguments)
def test_api_find_download(file_format, ids, format):
test_download_dir = './.pytest_downloads/'
mkdir(test_download_dir)
assert len(os.listdir(test_download_dir)) == 0
client = Client(verbose=True)
id_list = client.find('reactions:"P P --> LQ LQ"', ids=ids, format=format)
download_and_test(client, id_list, file_format, ids, '', test_download_dir)

# cli testing

Expand Down
20 changes: 11 additions & 9 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@

from click.testing import CliRunner

from hepdata_cli.api import Client
from hepdata_cli.api import Client, MAX_MATCHES, MATCHES_PER_PAGE
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why import MAX_MATCHES and MATCHES_PER_PAGE here? They're not used in this file.

from hepdata_cli.cli import cli


# arguments for testing

test_api_find_arguments = [
('reactions:"P P --> LQ LQ X"', None, None),
('reactions:"P P --> LQ LQ"', 'year', None),
('phrases:"(diffractive AND elastic)"', None, 'arxiv'),
('reactions:"P P --> LQ LQ X"', None, None, None),
('reactions:"P P --> LQ LQ"', 'year', None, None),
('phrases:"(diffractive AND elastic)"', None, 'arxiv', str),
('phrases:"(diffractive AND elastic)"', None, 'hepdata', list),
('reactions:"P P --> LQ LQ X"', None, 'arxiv', set),
('reactions:"P P --> LQ LQ X"', None, 'inspire', tuple),
]

test_cli_find_arguments = [
Expand All @@ -24,17 +27,16 @@

# api test

@pytest.mark.parametrize("query, keyword, ids", test_api_find_arguments)
def test_api_find(query, keyword, ids):
@pytest.mark.parametrize("query, keyword, ids, format", test_api_find_arguments)
def test_api_find(query, keyword, ids, format):
client = Client(verbose=True)
search_result = client.find(query, keyword, ids)
search_result = client.find(query, keyword, ids, format=format)
if ids is None:
assert type(search_result) is list
if len(search_result) > 0:
assert all([type(entry) is dict for entry in search_result])
else:
assert type(search_result) is str

assert type(search_result) is format

# cli testing

Expand Down