|
2 | 2 | import os |
3 | 3 | import re |
4 | 4 | import subprocess |
| 5 | +from contextlib import contextmanager |
5 | 6 | from pathlib import Path |
6 | 7 | from typing import List, Optional, Union |
7 | 8 |
|
8 | | -from .hf_api import HfFolder |
| 9 | +from .hf_api import ENDPOINT, HfApi, HfFolder |
9 | 10 | from .lfs import LFS_MULTIPART_UPLOAD_COMMAND |
10 | 11 |
|
11 | 12 |
|
@@ -87,12 +88,23 @@ def __init__( |
87 | 88 | """ |
88 | 89 |
|
89 | 90 | os.makedirs(local_dir, exist_ok=True) |
90 | | - self.local_dir = local_dir |
| 91 | + self.local_dir = os.path.join(os.getcwd(), local_dir) |
91 | 92 |
|
92 | 93 | self.check_git_versions() |
93 | 94 |
|
| 95 | + if isinstance(use_auth_token, str): |
| 96 | + self.huggingface_token = use_auth_token |
| 97 | + elif use_auth_token: |
| 98 | + self.huggingface_token = HfFolder.get_token() |
| 99 | + else: |
| 100 | + self.huggingface_token = None |
| 101 | + |
94 | 102 | if clone_from is not None: |
95 | | - self.clone_from(repo_url=clone_from, use_auth_token=use_auth_token) |
| 103 | + |
| 104 | + if "http" not in clone_from: |
| 105 | + clone_from = f"{ENDPOINT}/{clone_from}" |
| 106 | + |
| 107 | + self.clone_from(repo_url=clone_from) |
96 | 108 | else: |
97 | 109 | if is_git_repo(self.local_dir): |
98 | 110 | logger.debug("[Repository] is a valid git repo") |
@@ -147,24 +159,21 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non |
147 | 159 |
|
148 | 160 | If this folder is a git repository with linked history, will try to update the repository. |
149 | 161 | """ |
150 | | - if isinstance(use_auth_token, str): |
151 | | - huggingface_token = use_auth_token |
152 | | - elif use_auth_token: |
153 | | - huggingface_token = HfFolder.get_token() |
154 | | - else: |
155 | | - huggingface_token = None |
156 | | - |
157 | | - if ( |
158 | | - huggingface_token is not None |
159 | | - and "huggingface.co" in repo_url |
160 | | - and "@" not in repo_url |
161 | | - ): |
| 162 | + token = use_auth_token if use_auth_token is not None else self.huggingface_token |
| 163 | + if token is not None and "huggingface.co" in repo_url and "@" not in repo_url: |
| 164 | + endpoint = "/".join(repo_url.split("/")[:-2]) |
162 | 165 | # adds huggingface_token to repo url if it is provided. |
163 | 166 | # do not leak user token if it's not a repo on hf.co |
164 | | - repo_url = repo_url.replace( |
165 | | - "https://", f"https://user:{huggingface_token}@" |
166 | | - ) |
| 167 | + repo_url = repo_url.replace("https://", f"https://user:{token}@") |
| 168 | + |
| 169 | + organization, repo_id = repo_url.split("/")[-2:] |
167 | 170 |
|
| 171 | + HfApi(endpoint=endpoint).create_repo( |
| 172 | + token, |
| 173 | + repo_id, |
| 174 | + organization=organization, |
| 175 | + exist_ok=True, |
| 176 | + ) |
168 | 177 | # For error messages, it's cleaner to show the repo url without the token. |
169 | 178 | clean_repo_url = re.sub(r"https://.*@", "https://", repo_url) |
170 | 179 | try: |
@@ -432,3 +441,54 @@ def push_to_hub(self, commit_message="commit files to HF hub") -> str: |
432 | 441 | self.git_add() |
433 | 442 | self.git_commit(commit_message) |
434 | 443 | return self.git_push() |
| 444 | + |
@contextmanager
def commit(
    self,
    commit_message: str,
):
    """
    Context manager utility to handle committing to a repository.

    On entry the repository is pulled (with rebase) and the process working
    directory is switched to ``self.local_dir`` so that files written with
    relative paths inside the ``with`` body land in the repository. On exit
    (whether or not the body raised) everything is staged, committed with
    ``commit_message`` and pushed, and the previous working directory is
    always restored, even when staging, committing or pushing fails.

    Args:
        commit_message: Message used for the commit created on exit.

    Yields:
        This object, so the repository can be used inside the ``with`` body.

    Raises:
        OSError: If committing fails for a reason other than there being
            nothing to commit, or if pushing fails. A push failure caused by
            missing credentials is re-raised with a hint about
            ``use_auth_token``.

    Examples:

        >>> with Repository("text-files", clone_from="<user>/text-files", use_auth_token=True).commit("My first file :)"):
        ...     with open("file.txt", "w+") as f:
        ...         f.write(json.dumps({"hey": 8}))

        >>> import torch
        >>> model = torch.nn.Transformer()
        >>> with Repository("torch-model", clone_from="<user>/torch-model", use_auth_token=True).commit("My cool model :)"):
        ...     torch.save(model.state_dict(), "model.pt")

    """

    self.git_pull(rebase=True)

    current_working_directory = os.getcwd()
    # os.path.join returns `self.local_dir` unchanged when it is already
    # absolute, so this works for both relative and absolute local dirs.
    os.chdir(os.path.join(current_working_directory, self.local_dir))

    try:
        yield self
    finally:
        # The restore lives in its own `finally` so that a failing
        # add/commit/push cannot leave the process stranded inside the
        # repository directory.
        try:
            self.git_add()

            try:
                self.git_commit(commit_message)
            except OSError as e:
                # If no changes are detected, there is nothing to commit.
                if "nothing to commit" not in str(e):
                    raise

            try:
                self.git_push()
            except OSError as e:
                # git could not obtain credentials for the remote: surface
                # an actionable message instead of the raw git error.
                if "could not read Username" in str(e):
                    raise OSError(
                        "Couldn't authenticate user for push. Did you set `use_auth_token` to `True`?"
                    ) from e
                raise
        finally:
            os.chdir(current_working_directory)
0 commit comments