Skip to content

Commit b85ab03

Browse files
authored
Context manager (#143)
* Context Manager * Auth * Use Repository().commit instead * Remove from init * Better error message * Cleanup * Add example
1 parent a8493df commit b85ab03

File tree

2 files changed

+140
-18
lines changed

2 files changed

+140
-18
lines changed

src/huggingface_hub/repository.py

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import os
33
import re
44
import subprocess
5+
from contextlib import contextmanager
56
from pathlib import Path
67
from typing import List, Optional, Union
78

8-
from .hf_api import HfFolder
9+
from .hf_api import ENDPOINT, HfApi, HfFolder
910
from .lfs import LFS_MULTIPART_UPLOAD_COMMAND
1011

1112

@@ -87,12 +88,23 @@ def __init__(
8788
"""
8889

8990
os.makedirs(local_dir, exist_ok=True)
90-
self.local_dir = local_dir
91+
self.local_dir = os.path.join(os.getcwd(), local_dir)
9192

9293
self.check_git_versions()
9394

95+
if isinstance(use_auth_token, str):
96+
self.huggingface_token = use_auth_token
97+
elif use_auth_token:
98+
self.huggingface_token = HfFolder.get_token()
99+
else:
100+
self.huggingface_token = None
101+
94102
if clone_from is not None:
95-
self.clone_from(repo_url=clone_from, use_auth_token=use_auth_token)
103+
104+
if "http" not in clone_from:
105+
clone_from = f"{ENDPOINT}/{clone_from}"
106+
107+
self.clone_from(repo_url=clone_from)
96108
else:
97109
if is_git_repo(self.local_dir):
98110
logger.debug("[Repository] is a valid git repo")
@@ -147,24 +159,21 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
147159
148160
If this folder is a git repository with linked history, will try to update the repository.
149161
"""
150-
if isinstance(use_auth_token, str):
151-
huggingface_token = use_auth_token
152-
elif use_auth_token:
153-
huggingface_token = HfFolder.get_token()
154-
else:
155-
huggingface_token = None
156-
157-
if (
158-
huggingface_token is not None
159-
and "huggingface.co" in repo_url
160-
and "@" not in repo_url
161-
):
162+
token = use_auth_token if use_auth_token is not None else self.huggingface_token
163+
if token is not None and "huggingface.co" in repo_url and "@" not in repo_url:
164+
endpoint = "/".join(repo_url.split("/")[:-2])
162165
# adds huggingface_token to repo url if it is provided.
163166
# do not leak user token if it's not a repo on hf.co
164-
repo_url = repo_url.replace(
165-
"https://", f"https://user:{huggingface_token}@"
166-
)
167+
repo_url = repo_url.replace("https://", f"https://user:{token}@")
168+
169+
organization, repo_id = repo_url.split("/")[-2:]
167170

171+
HfApi(endpoint=endpoint).create_repo(
172+
token,
173+
repo_id,
174+
organization=organization,
175+
exist_ok=True,
176+
)
168177
# For error messages, it's cleaner to show the repo url without the token.
169178
clean_repo_url = re.sub(r"https://.*@", "https://", repo_url)
170179
try:
@@ -432,3 +441,54 @@ def push_to_hub(self, commit_message="commit files to HF hub") -> str:
432441
self.git_add()
433442
self.git_commit(commit_message)
434443
return self.git_push()
444+
445+
@contextmanager
446+
def commit(
447+
self,
448+
commit_message: str,
449+
):
450+
"""
451+
Context manager utility to handle committing to a repository.
452+
453+
Examples:
454+
455+
>>> with Repository("text-files", clone_from="<user>/text-files", use_auth_token=True).commit("My first file :)"):
456+
... with open("file.txt", "w+") as f:
457+
... f.write(json.dumps({"hey": 8}))
458+
459+
>>> import torch
460+
>>> model = torch.nn.Transformer()
461+
>>> with Repository("torch-model", clone_from="<user>/torch-model", use_auth_token=True).commit("My cool model :)"):
462+
... torch.save(model.state_dict(), "model.pt")
463+
464+
"""
465+
466+
self.git_pull(rebase=True)
467+
468+
current_working_directory = os.getcwd()
469+
os.chdir(os.path.join(current_working_directory, self.local_dir))
470+
471+
try:
472+
yield self
473+
finally:
474+
self.git_add()
475+
476+
try:
477+
self.git_commit(commit_message)
478+
except OSError as e:
479+
# If no changes are detected, there is nothing to commit.
480+
if "nothing to commit" not in str(e):
481+
raise e
482+
483+
try:
484+
self.git_push()
485+
except OSError as e:
486+
# If no changes are detected, there is nothing to commit.
487+
if "could not read Username" in str(e):
488+
raise OSError(
489+
"Couldn't authenticate user for push. Did you set `use_auth_token` to `True`?"
490+
) from e
491+
else:
492+
raise e
493+
494+
os.chdir(current_working_directory)

tests/test_repository.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,65 @@ def test_add_commit_push(self):
201201
# actually exists.
202202
r = requests.head(url)
203203
r.raise_for_status()
204+
205+
shutil.rmtree(WORKING_REPO_DIR)
206+
207+
def test_clone_with_repo_name_and_organization(self):
208+
clone = Repository(
209+
REPO_NAME,
210+
clone_from=f"{ENDPOINT_STAGING}/valid_org/{REPO_NAME}",
211+
use_auth_token=self._token,
212+
git_user="ci",
213+
git_email="[email protected]",
214+
)
215+
216+
with clone.commit("Commit"):
217+
with open("dummy.txt", "w") as f:
218+
f.write("hello")
219+
with open("model.bin", "w") as f:
220+
f.write("hello")
221+
222+
shutil.rmtree(REPO_NAME)
223+
224+
Repository(
225+
f"{WORKING_REPO_DIR}/{REPO_NAME}",
226+
clone_from=f"{ENDPOINT_STAGING}/valid_org/{REPO_NAME}",
227+
use_auth_token=self._token,
228+
git_user="ci",
229+
git_email="[email protected]",
230+
)
231+
232+
files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}")
233+
self.assertTrue("dummy.txt" in files)
234+
self.assertTrue("model.bin" in files)
235+
236+
def test_clone_with_repo_name_and_user_namespace(self):
237+
clone = Repository(
238+
REPO_NAME,
239+
clone_from=f"{ENDPOINT_STAGING}/{USER}/{REPO_NAME}",
240+
use_auth_token=self._token,
241+
git_user="ci",
242+
git_email="[email protected]",
243+
)
244+
245+
with clone.commit("Commit"):
246+
# Create dummy files
247+
# one is lfs-tracked, the other is not.
248+
with open("dummy.txt", "w") as f:
249+
f.write("hello")
250+
with open("model.bin", "w") as f:
251+
f.write("hello")
252+
253+
shutil.rmtree(REPO_NAME)
254+
255+
Repository(
256+
f"{WORKING_REPO_DIR}/{REPO_NAME}",
257+
clone_from=f"{ENDPOINT_STAGING}/{USER}/{REPO_NAME}",
258+
use_auth_token=self._token,
259+
git_user="ci",
260+
git_email="[email protected]",
261+
)
262+
263+
files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}")
264+
self.assertTrue("dummy.txt" in files)
265+
self.assertTrue("model.bin" in files)

0 commit comments

Comments
 (0)