Skip to content

Commit 5958f17

Browse files
Wauplin and osanseviero authored
Return more information in create_commit output (#1066)
* Return more information in create_commit output * flake8 * requested changes * fix autocomplete test * Add pr_revision and pr_url to CommitInfo * Update tests/test_hf_api.py Co-authored-by: Omar Sanseviero <[email protected]> * nicely handle properties in dataclass * make style Co-authored-by: Omar Sanseviero <[email protected]>
1 parent e8801bd commit 5958f17

File tree

7 files changed

+134
-21
lines changed

7 files changed

+134
-21
lines changed

docs/source/package_reference/hf_api.mdx

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,30 @@ models = hf_api.list_models()
2525

2626
Using the `HfApi` class directly enables you to set a different endpoint to that of the Hugging Face's Hub.
2727

28+
### HfApi
29+
2830
[[autodoc]] HfApi
2931

32+
### ModelInfo
33+
3034
[[autodoc]] huggingface_hub.hf_api.ModelInfo
3135

36+
### DatasetInfo
37+
3238
[[autodoc]] huggingface_hub.hf_api.DatasetInfo
3339

40+
### SpaceInfo
41+
3442
[[autodoc]] huggingface_hub.hf_api.SpaceInfo
3543

44+
### RepoFile
45+
3646
[[autodoc]] huggingface_hub.hf_api.RepoFile
3747

48+
### CommitInfo
49+
50+
[[autodoc]] huggingface_hub.hf_api.CommitInfo
51+
3852
## `create_commit` API
3953

4054
Below are the supported values for [`CommitOperation`]:
@@ -56,10 +70,18 @@ It does this using the [`HfFolder`] utility, which saves data at the root of the
5670

5771
Some helpers to filter repositories on the Hub are available in the `huggingface_hub` package.
5872

73+
### DatasetFilter
74+
5975
[[autodoc]] DatasetFilter
6076

77+
### ModelFilter
78+
6179
[[autodoc]] ModelFilter
6280

81+
### DatasetSearchArguments
82+
6383
[[autodoc]] DatasetSearchArguments
6484

85+
### ModelSearchArguments
86+
6587
[[autodoc]] ModelSearchArguments

src/huggingface_hub/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
"try_to_load_from_cache",
9494
],
9595
"hf_api": [
96+
"CommitInfo",
9697
"CommitOperation",
9798
"CommitOperationAdd",
9899
"CommitOperationDelete",
@@ -306,6 +307,7 @@ def __dir__():
306307
from .file_download import hf_hub_download # noqa: F401
307308
from .file_download import hf_hub_url # noqa: F401
308309
from .file_download import try_to_load_from_cache # noqa: F401
310+
from .hf_api import CommitInfo # noqa: F401
309311
from .hf_api import CommitOperation # noqa: F401
310312
from .hf_api import CommitOperationAdd # noqa: F401
311313
from .hf_api import CommitOperationDelete # noqa: F401

src/huggingface_hub/hf_api.py

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import re
1717
import subprocess
1818
import warnings
19+
from dataclasses import dataclass, field
1920
from typing import BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union
2021
from urllib.parse import quote
2122

@@ -171,6 +172,62 @@ class BlobLfsInfo(TypedDict, total=False):
171172
sha256: str
172173

173174

175+
@dataclass
class CommitInfo:
    """Data structure containing information about a newly created commit.

    Returned by [`create_commit`].

    Args:
        commit_url (`str`):
            Url where to find the commit.

        commit_message (`str`):
            The summary (first line) of the commit that has been created.

        commit_description (`str`):
            Description of the commit that has been created. Can be empty.

        oid (`str`):
            Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`.

        pr_url (`str`, *optional*):
            Url to the PR that has been created, if any. Populated when `create_pr=True`
            is passed.

        pr_revision (`str`, *optional*):
            Revision of the PR that has been created, if any. Populated when
            `create_pr=True` is passed. Example: `"refs/pr/1"`.

        pr_num (`int`, *optional*):
            Number of the PR discussion that has been created, if any. Populated when
            `create_pr=True` is passed. Can be passed as `discussion_num` in
            [`get_discussion_details`]. Example: `1`.
    """

    commit_url: str
    commit_message: str
    commit_description: str
    oid: str
    pr_url: Optional[str] = None

    # Computed from `pr_url` in `__post_init__`; not accepted by `__init__`.
    pr_revision: Optional[str] = field(init=False)
    # Annotated as `int` (not `str`): `__post_init__` stores the result of `int(...)`.
    pr_num: Optional[int] = field(init=False)

    def __post_init__(self):
        """Populate pr-related fields after initialization.

        When `pr_url` is set, `pr_revision` is parsed from it and `pr_num` is the
        integer found after the last `/` of that revision. Otherwise both are `None`.

        See https://docs.python.org/3.10/library/dataclasses.html#post-init-processing.
        """
        if self.pr_url is not None:
            self.pr_revision = _parse_revision_from_pr_url(self.pr_url)
            # The revision looks like "refs/pr/1"; its last segment is the PR number.
            self.pr_num = int(self.pr_revision.split("/")[-1])
        else:
            self.pr_revision = None
            self.pr_num = None
229+
230+
174231
class RepoFile:
175232
"""
176233
Data structure that represents a public file inside a repo, accessible from
@@ -1850,7 +1907,7 @@ def create_commit(
18501907
create_pr: Optional[bool] = None,
18511908
num_threads: int = 5,
18521909
parent_commit: Optional[str] = None,
1853-
) -> Optional[str]:
1910+
) -> CommitInfo:
18541911
"""
18551912
Creates a commit in the given repo, deleting & uploading files as needed.
18561913
@@ -1902,9 +1959,9 @@ def create_commit(
19021959
if the repo is updated / committed to concurrently.
19031960
19041961
Returns:
1905-
`str` or `None`:
1906-
If `create_pr` is `True`, returns the URL to the newly created Pull Request
1907-
on the Hub. Otherwise returns `None`.
1962+
[`CommitInfo`]:
1963+
Instance of [`CommitInfo`] containing information about the newly
1964+
created commit (commit hash, commit url, pr url, commit message,...).
19081965
19091966
Raises:
19101967
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
@@ -2015,7 +2072,14 @@ def create_commit(
20152072
params={"create_pr": "1"} if create_pr else None,
20162073
)
20172074
hf_raise_for_status(commit_resp, endpoint_name="commit")
2018-
return commit_resp.json().get("pullRequestUrl", None)
2075+
commit_data = commit_resp.json()
2076+
return CommitInfo(
2077+
commit_url=commit_data["commitUrl"],
2078+
commit_message=commit_message,
2079+
commit_description=commit_description,
2080+
oid=commit_data["commitOid"],
2081+
pr_url=commit_data["pullRequestUrl"] if create_pr else None,
2082+
)
20192083

20202084
@validate_hf_hub_args
20212085
def upload_file(
@@ -2157,7 +2221,7 @@ def upload_file(
21572221
path_in_repo=path_in_repo,
21582222
)
21592223

2160-
pr_url = self.create_commit(
2224+
commit_info = self.create_commit(
21612225
repo_id=repo_id,
21622226
repo_type=repo_type,
21632227
operations=[operation],
@@ -2169,8 +2233,8 @@ def upload_file(
21692233
parent_commit=parent_commit,
21702234
)
21712235

2172-
if pr_url is not None:
2173-
revision = quote(_parse_revision_from_pr_url(pr_url), safe="")
2236+
if commit_info.pr_url is not None:
2237+
revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="")
21742238
if repo_type in REPO_TYPES_URL_PREFIXES:
21752239
repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
21762240
revision = revision if revision is not None else DEFAULT_REVISION
@@ -2317,7 +2381,7 @@ def upload_folder(
23172381
ignore_patterns=ignore_patterns,
23182382
)
23192383

2320-
pr_url = self.create_commit(
2384+
commit_info = self.create_commit(
23212385
repo_type=repo_type,
23222386
repo_id=repo_id,
23232387
operations=files_to_add,
@@ -2329,8 +2393,8 @@ def upload_folder(
23292393
parent_commit=parent_commit,
23302394
)
23312395

2332-
if pr_url is not None:
2333-
revision = quote(_parse_revision_from_pr_url(pr_url), safe="")
2396+
if commit_info.pr_url is not None:
2397+
revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="")
23342398
if repo_type in REPO_TYPES_URL_PREFIXES:
23352399
repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
23362400
revision = revision if revision is not None else DEFAULT_REVISION
@@ -2350,7 +2414,7 @@ def delete_file(
23502414
commit_description: Optional[str] = None,
23512415
create_pr: Optional[bool] = None,
23522416
parent_commit: Optional[str] = None,
2353-
):
2417+
) -> CommitInfo:
23542418
"""
23552419
Deletes a file in the given repo.
23562420

src/huggingface_hub/keras_mixin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def push_to_hub_keras(
446446
allow_patterns=allow_patterns,
447447
ignore_patterns=ignore_patterns,
448448
)
449-
pr_url = api.create_commit(
449+
commit_info = api.create_commit(
450450
repo_type="model",
451451
repo_id=repo_id,
452452
operations=operations,
@@ -458,8 +458,8 @@ def push_to_hub_keras(
458458
revision = branch
459459
if revision is None:
460460
revision = (
461-
quote(_parse_revision_from_pr_url(pr_url), safe="")
462-
if pr_url is not None
461+
quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="")
462+
if commit_info.pr_url is not None
463463
else DEFAULT_REVISION
464464
)
465465
return f"{api.endpoint}/{repo_id}/tree/{revision}/"

src/huggingface_hub/utils/_validators.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import re
1818
from functools import wraps
1919
from itertools import chain
20-
from typing import Callable
20+
from typing import TypeVar
2121

2222

2323
REPO_ID_REGEX = re.compile(
@@ -40,7 +40,11 @@ class HFValidationError(ValueError):
4040
"""
4141

4242

43-
def validate_hf_hub_args(fn: Callable) -> Callable:
43+
# type hint meaning "function signature not changed by decorator"
44+
CallableT = TypeVar("CallableT") # callable type
45+
46+
47+
def validate_hf_hub_args(fn: CallableT) -> CallableT:
4448
"""Validate values received as argument for any public method of `huggingface_hub`.
4549
4650
The goal of this decorator is to harmonize validation of arguments reused

tests/test_hf_api.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from huggingface_hub.file_download import cached_download, hf_hub_download
4242
from huggingface_hub.hf_api import (
4343
USERNAME_PLACEHOLDER,
44+
CommitInfo,
4445
DatasetInfo,
4546
DatasetSearchArguments,
4647
HfApi,
@@ -766,10 +767,26 @@ def test_create_commit_create_pr(self):
766767
token=self._token,
767768
create_pr=True,
768769
)
770+
771+
# Check commit info
772+
self.assertIsInstance(resp, CommitInfo)
773+
commit_id = resp.oid
774+
self.assertIn("pr_revision='refs/pr/1'", str(resp))
775+
self.assertIsInstance(commit_id, str)
776+
self.assertGreater(len(commit_id), 0)
777+
self.assertEqual(
778+
resp.commit_url,
779+
f"{self._api.endpoint}/{USER}/{REPO_NAME}/commit/{commit_id}",
780+
)
781+
self.assertEqual(resp.commit_message, "Test create_commit")
782+
self.assertEqual(resp.commit_description, "")
769783
self.assertEqual(
770-
resp,
784+
resp.pr_url,
771785
f"{self._api.endpoint}/{USER}/{REPO_NAME}/discussions/1",
772786
)
787+
self.assertEqual(resp.pr_num, 1)
788+
self.assertEqual(resp.pr_revision, "refs/pr/1")
789+
773790
with self.assertRaises(HTTPError) as ctx:
774791
# Should raise a 404
775792
hf_hub_download(
@@ -830,13 +847,17 @@ def test_create_commit(self):
830847
path_or_fileobj=self.tmp_file,
831848
),
832849
]
833-
return_val = self._api.create_commit(
850+
resp = self._api.create_commit(
834851
operations=operations,
835852
commit_message="Test create_commit",
836853
repo_id=f"{USER}/{REPO_NAME}",
837854
token=self._token,
838855
)
839-
self.assertIsNone(return_val)
856+
# Check commit info
857+
self.assertIsInstance(resp, CommitInfo)
858+
self.assertIsNone(resp.pr_url) # No pr created
859+
self.assertIsNone(resp.pr_num)
860+
self.assertIsNone(resp.pr_revision)
840861
with self.assertRaises(HTTPError):
841862
# Should raise a 404
842863
hf_hub_download(

tests/test_init_lazy_loading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_autocomplete_on_root_imports(self) -> None:
3131
self.assertTrue(
3232
signature_list[0]
3333
.docstring()
34-
.startswith("create_commit(self, repo_id: str")
34+
.startswith("create_commit(repo_id: str,")
3535
)
3636
break
3737
else:

0 commit comments

Comments
 (0)