|
1 | 1 | import json |
2 | 2 | import logging |
3 | 3 | import os |
| 4 | +from pathlib import Path |
4 | 5 | from typing import Dict, Optional, Union |
5 | 6 |
|
6 | 7 | import requests |
@@ -200,74 +201,104 @@ def from_pretrained( |
200 | 201 |
|
201 | 202 | return model |
202 | 203 |
|
203 | | - @staticmethod |
204 | 204 | def push_to_hub( |
205 | | - save_directory: Optional[str], |
206 | | - model_id: Optional[str] = None, |
| 205 | + self, |
| 206 | + repo_path_or_name: Optional[str] = None, |
207 | 207 | repo_url: Optional[str] = None, |
208 | | - commit_message: Optional[str] = "add model", |
| 208 | + commit_message: Optional[str] = "Add model", |
209 | 209 | organization: Optional[str] = None, |
210 | | - private: bool = None, |
211 | | - api_endpoint=None, |
212 | | - use_auth_token: Union[bool, str, None] = None, |
| 210 | + private: Optional[bool] = None, |
| 211 | + api_endpoint: Optional[str] = None, |
| 212 | + use_auth_token: Optional[Union[bool, str]] = None, |
213 | 213 | git_user: Optional[str] = None, |
214 | 214 | git_email: Optional[str] = None, |
| 215 | + config: Optional[dict] = None, |
215 | 216 | ) -> str: |
216 | 217 | """ |
| 218 | + Upload model checkpoint or tokenizer files to the 🤗 Model Hub while synchronizing a local clone of the repo in |
| 219 | + :obj:`repo_path_or_name`. |
| 220 | +
|
217 | 221 | Parameters: |
218 | | - save_directory (:obj:`Union[str, os.PathLike]`): |
219 | | - Directory having model weights & config. |
220 | | - model_id (:obj:`str`, `optional`, defaults to :obj:`save_directory`): |
221 | | - Repo name in huggingface_hub. If not specified, repo name will be same as `save_directory` |
| 222 | + repo_path_or_name (:obj:`str`, `optional`): |
| 223 | + Can either be a repository name for your model or tokenizer in the Hub or a path to a local folder (in |
| 224 | + which case the repository will have the name of that local folder). If not specified, will default to |
| 225 | + the name given by :obj:`repo_url` and a local directory with that name will be created. |
222 | 226 | repo_url (:obj:`str`, `optional`): |
223 | | - Specify this in case you want to push to existing repo in hub. |
| 227 | + Specify this in case you want to push to an existing repository in the hub. If unspecified and |
| 228 | + :obj:`repo_path_or_name` does not exist locally, a new repository will be created in your namespace (unless |
| 229 | + you specify an :obj:`organization`), named after the last path component of :obj:`repo_path_or_name`. |
| 230 | + commit_message (:obj:`str`, `optional`, defaults to :obj:`"Add model"`): |
| 231 | + Message to use for the commit created when pushing. If not specified, the signature default |
| 232 | + :obj:`"Add model"` is used. |
224 | 233 | organization (:obj:`str`, `optional`): |
225 | | - Organization in which you want to push your model. |
| 234 | + Organization in which you want to push your model or tokenizer (you must be a member of this |
| 235 | + organization). |
226 | 236 | private (:obj:`bool`, `optional`): |
227 | | - private: Whether the model repo should be private (requires a paid huggingface.co account) |
228 | | - commit_message (:obj:`str`, `optional`, defaults to :obj:`add model`): |
229 | | - Message to commit while pushing |
| 237 | + Whether or not the repository created should be private (requires a paying subscription). |
230 | 238 | api_endpoint (:obj:`str`, `optional`): |
231 | 239 | The API endpoint to use when pushing the model to the hub. |
232 | | - use_auth_token (``str`` or ``bool``, `optional`, defaults ``None``): |
233 | | - huggingface_token can be extract from ``HfApi().login(username, password)`` and is used to authenticate |
234 | | - against the hub (useful from Google Colab for instance). |
235 | | - git_user (``str``, `optional`, defaults ``None``): |
| 240 | + use_auth_token (:obj:`bool` or :obj:`str`, `optional`): |
| 241 | + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token |
| 242 | + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). If left unset and |
| 243 | + :obj:`repo_url` is not specified, that stored token is looked up and a :obj:`ValueError` is raised if none is found. |
| 244 | + git_user (``str``, `optional`): |
236 | 245 | will override the ``git config user.name`` for committing and pushing files to the hub. |
237 | | - git_email (``str``, `optional`, defaults ``None``): |
| 246 | + git_email (``str``, `optional`): |
238 | 247 | will override the ``git config user.email`` for committing and pushing files to the hub. |
| 248 | + config (:obj:`dict`, `optional`): |
| 249 | + Configuration object to be saved alongside the model weights. |
| 250 | +
|
239 | 251 |
|
240 | 252 | Returns: |
241 | | - url to commit on remote repo. |
| 253 | + The url of the commit of your model in the given repository. |
242 | 254 | """ |
243 | | - if model_id is None: |
244 | | - model_id = save_directory.split("/")[-1] |
245 | | - |
246 | | - # The auth token is necessary to create a repo |
247 | | - if isinstance(use_auth_token, str): |
248 | | - huggingface_token = use_auth_token |
249 | | - elif use_auth_token is None and repo_url is not None: |
250 | | - # If the repo url exists, then no need for a token |
251 | | - huggingface_token = None |
| 255 | + |
| 256 | + if repo_path_or_name is None and repo_url is None: |
| 257 | + raise ValueError( |
| 258 | + "You need to specify a `repo_path_or_name` or a `repo_url`." |
| 259 | + ) |
| 260 | + |
| 261 | + if use_auth_token is None and repo_url is None: |
| 262 | + token = HfFolder.get_token() |
| 263 | + if token is None: |
| 264 | + raise ValueError( |
| 265 | + "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and " |
| 266 | + "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own " |
| 267 | + "token as the `use_auth_token` argument." |
| 268 | + ) |
| 269 | + elif isinstance(use_auth_token, str): |
| 270 | + token = use_auth_token |
252 | 271 | else: |
253 | | - huggingface_token = HfFolder.get_token() |
| 272 | + token = None |
| 273 | + |
| 274 | + if repo_path_or_name is None: |
| 275 | + repo_path_or_name = repo_url.split("/")[-1] |
254 | 276 |
|
255 | | - if repo_url is None: |
| 277 | + # If no URL is passed and there's no path to a directory containing files, create a repo |
| 278 | + if repo_url is None and not os.path.exists(repo_path_or_name): |
| 279 | + repo_name = Path(repo_path_or_name).name |
256 | 280 | repo_url = HfApi(endpoint=api_endpoint).create_repo( |
257 | | - huggingface_token, |
258 | | - model_id, |
| 281 | + token, |
| 282 | + repo_name, |
259 | 283 | organization=organization, |
260 | 284 | private=private, |
261 | 285 | repo_type=None, |
262 | 286 | exist_ok=True, |
263 | 287 | ) |
264 | 288 |
|
265 | 289 | repo = Repository( |
266 | | - save_directory, |
| 290 | + repo_path_or_name, |
267 | 291 | clone_from=repo_url, |
268 | 292 | use_auth_token=use_auth_token, |
269 | 293 | git_user=git_user, |
270 | 294 | git_email=git_email, |
271 | 295 | ) |
| 296 | + repo.git_pull(rebase=True) |
| 297 | + |
| 298 | + # Save the files in the cloned repo |
| 299 | + self.save_pretrained(repo_path_or_name, config=config) |
272 | 300 |
|
273 | | - return repo.push_to_hub(commit_message=commit_message) |
| 301 | + # Commit and push! |
| 302 | + repo.git_add() |
| 303 | + repo.git_commit(commit_message) |
| 304 | + return repo.git_push() |
0 commit comments