Skip to content

Commit 4b8384d

Browse files
julien-c and philschmid authored
Repository helper class to make it easy to push programmatically
* Created HF Repository class for model hub (#7) * Created HF Repository class for model hub * fixed issues with pushing to hub from existing files * Update HfRepostiory.py * Update HfRepostiory.py * refactored HFRepository * removed model_card * fixed push to hub * make style * added comments * fixed remote user user * doc * rename this as not a lot (if any) is going to be HF-specific * black * Update doc with the philosophy of the class * work-in-progress (reorganize) * Still work-in-progress, but this is starting to take shape * Fix stuff Co-authored-by: Julien Chaumond <[email protected]> * Fix CI * [Python 3.6] remove capture_output=True * Oops * Fix CI Co-authored-by: Philipp Schmid <[email protected]>
1 parent 07cbb32 commit 4b8384d

File tree

7 files changed

+379
-9
lines changed

7 files changed

+379
-9
lines changed

src/huggingface_hub/commands/lfs.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,12 @@
2727

2828
import requests
2929
from huggingface_hub.commands import BaseHuggingfaceCLICommand
30+
from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND
3031

3132

3233
logger = logging.getLogger(__name__)
3334

3435

35-
LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload"
36-
37-
3836
class LfsCommands(BaseHuggingfaceCLICommand):
3937
"""
4038
Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs. This lets users upload

src/huggingface_hub/hf_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ def create_repo(
158158
exist_ok: Do not raise an error if repo already exists
159159
160160
lfsmultipartthresh: Optional: internal param for testing purposes.
161+
162+
Returns:
163+
URL to the newly created repo.
161164
"""
162165
path = "{}/api/repos/create".format(self.endpoint)
163166

src/huggingface_hub/lfs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# coding=utf-8
2+
# Copyright 2019-present, the HuggingFace Inc. team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
# Name of the CLI sub-command passed as the `args` of the git-lfs custom
# transfer agent for the "multipart" transfer type.
LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload"

src/huggingface_hub/repository.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
import logging
2+
import os
3+
import subprocess
4+
from typing import List, Optional, Union
5+
6+
from .hf_api import HfFolder
7+
from .lfs import LFS_MULTIPART_UPLOAD_COMMAND
8+
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class Repository:
    """
    Helper class to wrap the git and git-lfs commands.

    The aim is to facilitate interacting with huggingface.co hosted model or dataset repos,
    though not a lot here (if any) is actually specific to huggingface.co.
    """

    def __init__(
        self,
        local_dir: str,
        clone_from: Optional[str] = None,
        use_auth_token: Union[bool, str, None] = None,
        git_user: Optional[str] = None,
        git_email: Optional[str] = None,
    ):
        """
        Instantiate a local clone of a git repo.

        If specifying a `clone_from`:
        will clone an existing remote repository
        that was previously created using ``HfApi().create_repo(token=huggingface_token, name=repo_name)``.
        ``Repository`` uses the local git credentials by default, but if required, the ``huggingface_token``
        as well as the git ``user`` and the ``email`` can be specified.
        ``Repository`` will then override them.
        If `clone_from` is used, and the repository is being instantiated into a non-empty directory,
        e.g. a directory with your trained model files, it will automatically merge them.

        If `clone_from` is not specified, ``local_dir`` must already be a valid git clone.

        Args:
            local_dir (``str``):
                path (e.g. ``'my_trained_model/'``) to the local directory in which the repository is
                cloned, or which already contains a git clone. It is created if it does not exist.
            clone_from (``str``, optional):
                repository url (e.g. ``'https://huggingface.co/philschmid/playground-tests'``).
            use_auth_token (``str`` or ``bool``, `optional`, defaults ``None``):
                huggingface_token can be extracted from ``HfApi().login(username, password)`` and is used
                to authenticate against the hub. A ``str`` is used as the token itself; ``True`` reads
                the token through ``HfFolder.get_token()``.
            git_user (``str``, `optional`, defaults ``None``):
                will override the ``git config user.name`` for committing and pushing files to the hub.
            git_email (``str``, `optional`, defaults ``None``):
                will override the ``git config user.email`` for committing and pushing files to the hub.

        Raises:
            EnvironmentError: if git or git-lfs is not installed.
            ValueError: if `clone_from` is not given and ``git remote -v`` fails in ``local_dir``.
        """

        os.makedirs(local_dir, exist_ok=True)
        self.local_dir = local_dir

        self.check_git_versions()

        if clone_from is not None:
            self.clone_from(repo_url=clone_from, use_auth_token=use_auth_token)
        else:
            try:
                remotes = subprocess.check_output(
                    ["git", "remote", "-v"],
                    encoding="utf-8",
                    cwd=self.local_dir,
                )
            except subprocess.CalledProcessError as err:
                message = "If not specifying `clone_from`, you need to pass Repository a valid git clone."
                logger.error(message)
                raise ValueError(message) from err
            logger.debug("[Repository] has remotes")
            logger.debug(remotes)

        # overrides .git config if user and email is provided.
        if git_user is not None or git_email is not None:
            self.git_config_username_and_email(git_user, git_email)

    def check_git_versions(self):
        """
        Log git and git-lfs versions, raises if they aren't installed.

        Raises:
            EnvironmentError: if the ``git`` or ``git-lfs`` executable cannot be found.
        """
        try:
            git_version = subprocess.check_output(
                ["git", "--version"], encoding="utf-8"
            ).strip()
        except FileNotFoundError as err:
            raise EnvironmentError(
                "Looks like you do not have git installed, please install."
            ) from err

        try:
            lfs_version = subprocess.check_output(
                ["git-lfs", "--version"],
                encoding="utf-8",
            ).strip()
        except FileNotFoundError as err:
            raise EnvironmentError(
                "Looks like you do not have git-lfs installed, please install."
                " You can install from https://git-lfs.github.com/."
                " Then run `git lfs install` (you only have to do this once)."
            ) from err
        logger.info(git_version + "\n" + lfs_version)

    @staticmethod
    def _inject_token(repo_url: str, token: str) -> str:
        """
        Return ``repo_url`` with ``user:<token>@`` credentials added, when it is safe to do so.

        The token is only added when the url uses ``https://``, carries no ``@`` already,
        and its *hostname* is ``huggingface.co`` or one of its subdomains. Checking the
        parsed hostname (rather than a substring of the url) avoids leaking the token to
        hosts such as ``huggingface.co.evil.example`` or to urls that merely contain
        ``huggingface.co`` in their path. In every other case the url is returned unchanged.
        """
        from urllib.parse import urlparse

        https_prefix = "https://"
        if not repo_url.startswith(https_prefix) or "@" in repo_url:
            return repo_url

        try:
            hostname = urlparse(repo_url).hostname or ""
        except ValueError:
            # Malformed url (e.g. broken IPv6 literal): never attach credentials to it.
            return repo_url

        if hostname != "huggingface.co" and not hostname.endswith(".huggingface.co"):
            # do not leak user token if it's not a repo on hf.co
            return repo_url

        # Only the leading scheme is rewritten, never a later "https://" inside the url.
        return f"{https_prefix}user:{token}@{repo_url[len(https_prefix):]}"

    def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = None):
        """
        Clone from a remote.

        If ``local_dir`` is empty, a plain ``git clone`` is performed. Otherwise the
        directory is initialized as a git repo, ``origin`` is added and fetched, and
        ``origin/main`` is reset to / checked out over the existing files.

        Args:
            repo_url (``str``):
                url of the remote repository.
            use_auth_token (``str`` or ``bool``, `optional`, defaults ``None``):
                a ``str`` is used as the token; ``True`` reads it through
                ``HfFolder.get_token()``; anything falsy means no token.
        """
        if isinstance(use_auth_token, str):
            huggingface_token = use_auth_token
        elif use_auth_token:
            huggingface_token = HfFolder.get_token()
        else:
            huggingface_token = None

        if huggingface_token is not None:
            # adds huggingface_token to repo url if it is provided (hf.co hosts only).
            repo_url = self._inject_token(repo_url, huggingface_token)

        subprocess.run(["git", "lfs", "install"], check=True)

        # checks if repository is initialized in a empty repository or in one with files
        if not os.listdir(self.local_dir):
            subprocess.run(
                ["git", "clone", repo_url, "."], check=True, cwd=self.local_dir
            )
        else:
            logger.warning(
                "[Repository] local_dir is not empty, so let's try to pull the remote over a non-empty folder."
            )
            subprocess.run(["git", "init"], check=True, cwd=self.local_dir)
            subprocess.run(
                ["git", "remote", "add", "origin", repo_url],
                check=True,
                cwd=self.local_dir,
            )
            subprocess.run(["git", "fetch"], check=True, cwd=self.local_dir)
            subprocess.run(
                ["git", "reset", "origin/main"], check=True, cwd=self.local_dir
            )
            # TODO(check if we really want the --force flag)
            subprocess.run(
                ["git", "checkout", "origin/main", "-ft"],
                check=True,
                cwd=self.local_dir,
            )

    def git_config_username_and_email(
        self, git_user: Optional[str] = None, git_email: Optional[str] = None
    ):
        """
        sets git user name and email (only in the current repo)

        Values are passed to git as single arguments, so names containing
        whitespace (e.g. ``"First Last"``) are stored intact.
        """
        if git_user is not None:
            subprocess.run(
                ["git", "config", "user.name", git_user],
                check=True,
                cwd=self.local_dir,
            )
        if git_email is not None:
            subprocess.run(
                ["git", "config", "user.email", git_email],
                check=True,
                cwd=self.local_dir,
            )

    def lfs_track(self, patterns: List[str]):
        """
        Tell git-lfs to track those files.

        Args:
            patterns: patterns handed one by one to ``git lfs track``.
        """
        for pattern in patterns:
            subprocess.run(
                ["git", "lfs", "track", pattern], check=True, cwd=self.local_dir
            )

    def lfs_enable_largefiles(self):
        """
        HF-specific. This enables upload support of files >5GB.

        Registers ``huggingface-cli`` as the git-lfs custom transfer agent for
        the "multipart" transfer type in the current repo's git config.
        """
        subprocess.run(
            ["git", "config", "lfs.customtransfer.multipart.path", "huggingface-cli"],
            check=True,
            cwd=self.local_dir,
        )
        subprocess.run(
            [
                "git",
                "config",
                "lfs.customtransfer.multipart.args",
                LFS_MULTIPART_UPLOAD_COMMAND,
            ],
            check=True,
            cwd=self.local_dir,
        )

    def git_pull(self, rebase: Optional[bool] = False):
        """
        git pull

        Args:
            rebase: when truthy, ``--rebase`` is appended to the command.
        """
        args = ["git", "pull"]
        if rebase:
            args.append("--rebase")
        subprocess.run(args, check=True, cwd=self.local_dir)

    def git_add(self, pattern="."):
        """
        git add

        Args:
            pattern: path or pattern to stage; defaults to ``"."`` (everything in the repo dir).
        """
        subprocess.run(["git", "add", pattern], check=True, cwd=self.local_dir)

    def git_commit(self, commit_message="commit files to HF hub"):
        """
        git commit

        Args:
            commit_message: message passed to ``git commit -m``.
        """
        subprocess.run(
            ["git", "commit", "-m", commit_message], check=True, cwd=self.local_dir
        )

    def git_push(self):
        """
        git push
        """
        subprocess.run(["git", "push"], check=True, cwd=self.local_dir)

    def push_to_hub(self, commit_message="commit files to HF hub"):
        """
        Helper to add, commit, and push files to remote repository on the HuggingFace Hub.
        Args:
            commit_message: commit message.
        """
        self.git_add()
        self.git_commit(commit_message)
        self.git_push()

tests/test_hf_api.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,10 @@
2323
from huggingface_hub.hf_api import HfApi, HfFolder, ModelInfo, RepoObj
2424
from requests.exceptions import HTTPError
2525

26+
from .testing_constants import ENDPOINT_STAGING, ENDPOINT_STAGING_BASIC_AUTH, PASS, USER
2627
from .testing_utils import require_git_lfs
2728

2829

29-
USER = "__DUMMY_TRANSFORMERS_USER__"
30-
PASS = "__DUMMY_TRANSFORMERS_PASS__"
31-
32-
ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
33-
ENDPOINT_STAGING_BASIC_AUTH = f"https://{USER}:{PASS}@moon-staging.huggingface.co"
34-
3530
REPO_NAME = "my-model-{}".format(int(time.time() * 10e3))
3631
REPO_NAME_LARGE_FILE = "my-model-largefiles-{}".format(int(time.time() * 10e3))
3732
DATASET_REPO_NAME = "my-dataset-{}".format(int(time.time() * 10e3))

0 commit comments

Comments
 (0)