Skip to content

Commit 5ca3078

Browse files
LysandreJik, sgugger, julien-c
authored
Repository power-up (#132)
* Repository power-up * Setup update * Setup update * Remove token for the test with remotes * Expose helpers * Slight reword * Add types * Re-iterate * Apply Julien's suggestion Co-authored-by: julien-c <[email protected]> * push_to_hub rework * Apply Julien's comments Co-authored-by: julien-c <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: julien-c <[email protected]>
1 parent dbea604 commit 5ca3078

File tree

4 files changed

+218
-115
lines changed

4 files changed

+218
-115
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 68 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import logging
33
import os
4+
from pathlib import Path
45
from typing import Dict, Optional, Union
56

67
import requests
@@ -200,74 +201,104 @@ def from_pretrained(
200201

201202
return model
202203

203-
@staticmethod
204204
def push_to_hub(
205-
save_directory: Optional[str],
206-
model_id: Optional[str] = None,
205+
self,
206+
repo_path_or_name: Optional[str] = None,
207207
repo_url: Optional[str] = None,
208-
commit_message: Optional[str] = "add model",
208+
commit_message: Optional[str] = "Add model",
209209
organization: Optional[str] = None,
210-
private: bool = None,
211-
api_endpoint=None,
212-
use_auth_token: Union[bool, str, None] = None,
210+
private: Optional[bool] = None,
211+
api_endpoint: Optional[str] = None,
212+
use_auth_token: Optional[Union[bool, str]] = None,
213213
git_user: Optional[str] = None,
214214
git_email: Optional[str] = None,
215+
config: Optional[dict] = None,
215216
) -> str:
216217
"""
218+
Upload model checkpoint or tokenizer files to the 🤗 Model Hub while synchronizing a local clone of the repo in
219+
:obj:`repo_path_or_name`.
220+
217221
Parameters:
218-
save_directory (:obj:`Union[str, os.PathLike]`):
219-
Directory having model weights & config.
220-
model_id (:obj:`str`, `optional`, defaults to :obj:`save_directory`):
221-
Repo name in huggingface_hub. If not specified, repo name will be same as `save_directory`
222+
repo_path_or_name (:obj:`str`, `optional`):
223+
Can either be a repository name for your model or tokenizer in the Hub or a path to a local folder (in
224+
which case the repository will have the name of that local folder). If not specified, will default to
225+
the name given by :obj:`repo_url` and a local directory with that name will be created.
222226
repo_url (:obj:`str`, `optional`):
223-
Specify this in case you want to push to existing repo in hub.
227+
Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
228+
repository will be created in your namespace (unless you specify an :obj:`organization`) with
229+
the name derived from :obj:`repo_path_or_name`.
230+
commit_message (:obj:`str`, `optional`):
231+
Message to commit while pushing. Will default to
232+
:obj:`"Add model"`.
224233
organization (:obj:`str`, `optional`):
225-
Organization in which you want to push your model.
234+
Organization in which you want to push your model or tokenizer (you must be a member of this
235+
organization).
226236
private (:obj:`bool`, `optional`):
227-
private: Whether the model repo should be private (requires a paid huggingface.co account)
228-
commit_message (:obj:`str`, `optional`, defaults to :obj:`add model`):
229-
Message to commit while pushing
237+
Whether or not the repository created should be private (requires a paying subscription).
230238
api_endpoint (:obj:`str`, `optional`):
231239
The API endpoint to use when pushing the model to the hub.
232-
use_auth_token (``str`` or ``bool``, `optional`, defaults ``None``):
233-
huggingface_token can be extract from ``HfApi().login(username, password)`` and is used to authenticate
234-
against the hub (useful from Google Colab for instance).
235-
git_user (``str``, `optional`, defaults ``None``):
240+
use_auth_token (:obj:`bool` or :obj:`str`, `optional`):
241+
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
242+
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). Will default to
243+
:obj:`True` if :obj:`repo_url` is not specified.
244+
git_user (``str``, `optional`):
236245
will override the ``git config user.name`` for committing and pushing files to the hub.
237-
git_email (``str``, `optional`, defaults ``None``):
246+
git_email (``str``, `optional`):
238247
will override the ``git config user.email`` for committing and pushing files to the hub.
248+
config (:obj:`dict`, `optional`):
249+
Configuration object to be saved alongside the model weights.
250+
239251
240252
Returns:
241-
url to commit on remote repo.
253+
The url of the commit of your model in the given repository.
242254
"""
243-
if model_id is None:
244-
model_id = save_directory.split("/")[-1]
245-
246-
# The auth token is necessary to create a repo
247-
if isinstance(use_auth_token, str):
248-
huggingface_token = use_auth_token
249-
elif use_auth_token is None and repo_url is not None:
250-
# If the repo url exists, then no need for a token
251-
huggingface_token = None
255+
256+
if repo_path_or_name is None and repo_url is None:
257+
raise ValueError(
258+
"You need to specify a `repo_path_or_name` or a `repo_url`."
259+
)
260+
261+
if use_auth_token is None and repo_url is None:
262+
token = HfFolder.get_token()
263+
if token is None:
264+
raise ValueError(
265+
"You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
266+
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
267+
"token as the `use_auth_token` argument."
268+
)
269+
elif isinstance(use_auth_token, str):
270+
token = use_auth_token
252271
else:
253-
huggingface_token = HfFolder.get_token()
272+
token = None
273+
274+
if repo_path_or_name is None:
275+
repo_path_or_name = repo_url.split("/")[-1]
254276

255-
if repo_url is None:
277+
# If no URL is passed and there's no path to a directory containing files, create a repo
278+
if repo_url is None and not os.path.exists(repo_path_or_name):
279+
repo_name = Path(repo_path_or_name).name
256280
repo_url = HfApi(endpoint=api_endpoint).create_repo(
257-
huggingface_token,
258-
model_id,
281+
token,
282+
repo_name,
259283
organization=organization,
260284
private=private,
261285
repo_type=None,
262286
exist_ok=True,
263287
)
264288

265289
repo = Repository(
266-
save_directory,
290+
repo_path_or_name,
267291
clone_from=repo_url,
268292
use_auth_token=use_auth_token,
269293
git_user=git_user,
270294
git_email=git_email,
271295
)
296+
repo.git_pull(rebase=True)
297+
298+
# Save the files in the cloned repo
299+
self.save_pretrained(repo_path_or_name, config=config)
272300

273-
return repo.push_to_hub(commit_message=commit_message)
301+
# Commit and push!
302+
repo.git_add()
303+
repo.git_commit(commit_message)
304+
return repo.git_push()

src/huggingface_hub/repository.py

Lines changed: 71 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import re
44
import subprocess
5+
from pathlib import Path
56
from typing import List, Optional, Union
67

78
from .hf_api import HfFolder
@@ -11,6 +12,38 @@
1112
logger = logging.getLogger(__name__)
1213

1314

15+
def is_git_repo(folder: Union[str, Path]) -> bool:
    """
    Check if the folder is the root of a git repository.

    Args:
        folder (:obj:`str` or :obj:`Path`):
            Directory to inspect.

    Returns:
        :obj:`bool`: ``True`` when ``folder`` contains a ``.git`` entry and ``git branch`` exits with status 0
        when run from inside it; ``False`` otherwise, including when ``folder`` does not exist.
    """
    # Without a `.git` entry the folder cannot be a repository root. Returning early also avoids
    # launching git with a `cwd` that may not exist, which would make `subprocess.run` raise
    # instead of letting this predicate answer `False`.
    if not os.path.exists(os.path.join(folder, ".git")):
        return False

    git_branch = subprocess.run(
        ["git", "branch"],
        cwd=folder,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    return git_branch.returncode == 0
24+
25+
26+
def is_local_clone(folder: Union[str, Path], remote_url: str) -> bool:
    """
    Check if the folder is a local clone of the remote_url.

    Args:
        folder (:obj:`str` or :obj:`Path`):
            Directory to inspect.
        remote_url (:obj:`str`):
            URL of the remote repository to look for among the remotes configured in ``folder``.

    Returns:
        :obj:`bool`: ``True`` when ``folder`` is a git repository and one of the entries printed by
        ``git remote -v`` equals ``remote_url`` once credentials are stripped from both; ``False`` otherwise.

    Raises:
        subprocess.CalledProcessError: If ``git remote -v`` exits with a non-zero status.
    """
    if not is_git_repo(folder):
        return False

    remotes = subprocess.run(
        ["git", "remote", "-v"],
        stderr=subprocess.PIPE,
        stdout=subprocess.PIPE,
        check=True,
        encoding="utf-8",
        cwd=folder,
    ).stdout

    # Remove token for the test with remotes. The same substitution is applied to the configured
    # remotes so that a remote stored with embedded credentials still matches a bare URL.
    credentials = re.compile(r"https://.*@")
    remote_url = credentials.sub("https://", remote_url)
    known_remotes = [credentials.sub("https://", remote) for remote in remotes.split()]
    return remote_url in known_remotes
45+
46+
1447
class Repository:
1548
"""
1649
Helper class to wrap the git and git-lfs commands.
@@ -60,7 +93,7 @@ def __init__(
6093
if clone_from is not None:
6194
self.clone_from(repo_url=clone_from, use_auth_token=use_auth_token)
6295
else:
63-
if os.path.isdir(os.path.join(self.local_dir, ".git")):
96+
if is_git_repo(self.local_dir):
6497
logger.debug("[Repository] is a valid git repo")
6598
else:
6699
logger.error(
@@ -109,7 +142,9 @@ def check_git_versions(self):
109142

110143
def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = None):
111144
"""
112-
Clone from a remote.
145+
Clone from a remote. If the folder already exists, will try to clone the repository within it.
146+
147+
If this folder is a git repository with linked history, will try to update the repository.
113148
"""
114149
if isinstance(use_auth_token, str):
115150
huggingface_token = use_auth_token
@@ -139,6 +174,7 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
139174

140175
# checks if repository is initialized in an empty repository or in one with files
141176
if len(os.listdir(self.local_dir)) == 0:
177+
logger.debug(f"Cloning {repo_url} into local empty directory.")
142178
subprocess.run(
143179
["git", "clone", repo_url, "."],
144180
stderr=subprocess.PIPE,
@@ -148,71 +184,39 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
148184
cwd=self.local_dir,
149185
)
150186
else:
151-
logger.warning(
152-
"[Repository] local_dir is not empty, so let's try to pull the remote over a non-empty folder."
153-
)
154-
subprocess.run(
155-
"git init".split(),
156-
stderr=subprocess.PIPE,
157-
stdout=subprocess.PIPE,
158-
check=True,
159-
encoding="utf-8",
160-
cwd=self.local_dir,
161-
)
162-
163-
output = subprocess.run(
164-
"git remote -v".split(),
165-
stderr=subprocess.PIPE,
166-
stdout=subprocess.PIPE,
167-
check=True,
168-
encoding="utf-8",
169-
cwd=self.local_dir,
170-
)
171-
172-
if "origin" not in output.stdout.split():
173-
subprocess.run(
174-
["git", "remote", "add", "origin", repo_url],
175-
stderr=subprocess.PIPE,
176-
stdout=subprocess.PIPE,
177-
check=True,
178-
encoding="utf-8",
179-
cwd=self.local_dir,
180-
)
181-
182-
subprocess.run(
183-
"git fetch".split(),
184-
stderr=subprocess.PIPE,
185-
stdout=subprocess.PIPE,
186-
check=True,
187-
encoding="utf-8",
188-
cwd=self.local_dir,
189-
)
190-
subprocess.run(
191-
"git reset origin/main".split(),
192-
stderr=subprocess.PIPE,
193-
stdout=subprocess.PIPE,
194-
encoding="utf-8",
195-
check=True,
196-
cwd=self.local_dir,
197-
)
198-
199-
output = subprocess.run(
200-
"git branch".split(),
201-
stderr=subprocess.PIPE,
202-
stdout=subprocess.PIPE,
203-
check=True,
204-
encoding="utf-8",
205-
cwd=self.local_dir,
206-
)
207-
208-
if "main" not in output.stdout.split():
209-
subprocess.run(
210-
"git checkout origin/main -t".split(),
211-
stderr=subprocess.PIPE,
212-
stdout=subprocess.PIPE,
213-
encoding="utf-8",
214-
check=True,
215-
cwd=self.local_dir,
187+
# Check if the folder is the root of a git repository
188+
in_repository = is_git_repo(self.local_dir)
189+
190+
if in_repository:
191+
if is_local_clone(self.local_dir, repo_url):
192+
logger.debug(
193+
f"{self.local_dir} is already a clone of {repo_url}. Make sure you pull the latest"
194+
"changes with `repo.git_pull()`."
195+
)
196+
else:
197+
output = subprocess.run(
198+
"git remote get-url origin".split(),
199+
stderr=subprocess.PIPE,
200+
stdout=subprocess.PIPE,
201+
encoding="utf-8",
202+
cwd=self.local_dir,
203+
)
204+
205+
error_msg = (
206+
f"Tried to clone {repo_url} in an unrelated git repository.\nIf you believe this is an "
207+
f"error, please add a remote with the following URL: {repo_url}."
208+
)
209+
if output.returncode == 0:
210+
error_msg += f"\nLocal path has its origin defined as: {output.stdout}"
211+
212+
raise EnvironmentError(error_msg)
213+
214+
if not in_repository:
215+
raise EnvironmentError(
216+
"Tried to clone a repository in a non-empty folder that isn't a git repository. If you really "
217+
"want to do this, do it manually:\n"
218+
"git init && git remote add origin && git pull origin main\n"
219+
" or clone repo to a new folder and move your existing files there afterwards."
216220
)
217221

218222
except subprocess.CalledProcessError as exc:

tests/test_hubmixin.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,13 @@ def test_abs_path_from_pretrained(self):
113113

114114
def test_push_to_hub(self):
115115
model = DummyModel()
116-
model.save_pretrained(
117-
f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB",
118-
config={"num": 7, "act": "gelu_fast"},
119-
)
120-
121116
model.push_to_hub(
122-
f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB",
123-
f"{REPO_NAME}-PUSH_TO_HUB",
117+
repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB",
124118
api_endpoint=ENDPOINT_STAGING,
125119
use_auth_token=self._token,
126120
git_user="ci",
127121
git_email="[email protected]",
122+
config={"num": 7, "act": "gelu_fast"},
128123
)
129124

130125
model_info = self._api.model_info(

0 commit comments

Comments
 (0)