Skip to content

Commit 9e1d3d5

Browse files
Adding mixin class for easy saving, uploading, and downloading (as discussed in issue #9). (#11)
* work initiated * start upload_to_hub * add changes * final-push * i feel this is better. * updated for Repository class * small updates * fix multiple calling * small fix * make style * add everything * minor fix * minor fix * done everything * small fix * [doc] remove mention of TF support * Fix typings (i think) * We do NOT want to have a hard requirement on torch * Fix flake8 * Fix CI Co-authored-by: Julien Chaumond <[email protected]>
1 parent be902d8 commit 9e1d3d5

File tree

4 files changed

+343
-6
lines changed

4 files changed

+343
-6
lines changed

src/huggingface_hub/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@
3030
)
3131
from .file_download import cached_download, hf_hub_url
3232
from .hf_api import HfApi, HfFolder
33+
from .hub_mixin import ModelHubMixin
3334
from .repository import Repository

src/huggingface_hub/hub_mixin.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
import json
2+
import logging
3+
import os
4+
from typing import Dict, Optional
5+
6+
import requests
7+
8+
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
9+
from .file_download import cached_download, hf_hub_url, is_torch_available
10+
from .hf_api import HfApi, HfFolder
11+
from .repository import Repository
12+
13+
14+
if is_torch_available():
15+
import torch
16+
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
class ModelHubMixin(object):
    """
    Mix this class with your torch-model class for ease process of saving & loading from huggingface-hub

    Example::

        >>> from huggingface_hub import ModelHubMixin

        >>> class MyModel(nn.Module, ModelHubMixin):
        ...    def __init__(self, **kwargs):
        ...        super().__init__()
        ...        self.config = kwargs.pop("config", None)
        ...        self.layer = ...
        ...    def forward(self, ...):
        ...        return ...

        >>> model = MyModel()
        >>> model.save_pretrained("mymodel", push_to_hub=False) # Saving model weights in the directory
        >>> model.push_to_hub("mymodel", "model-1") # Pushing model-weights to hf-hub

        >>> # Downloading weights from hf-hub & model will be initialized from those weights
        >>> model = MyModel.from_pretrained("username/model-1@main")
    """

    def __init__(self, *args, **kwargs):
        """
        No-op initializer: accepts any positional / keyword arguments and stores nothing.
        """

    @staticmethod
    def _import_torch():
        """
        Import and return the ``torch`` module, raising a readable :obj:`ImportError` when it is missing.

        torch is an optional dependency of this package, so it is imported lazily by the methods that need it.
        """
        try:
            import torch
        except ImportError as err:
            raise ImportError(
                "ModelHubMixin needs PyTorch to save or load model weights, but `torch` could not be imported."
            ) from err
        return torch

    @staticmethod
    def _split_revision(name):
        """
        Split ``"model_id@revision"`` into ``(model_id, revision)``.

        The split only happens when ``name`` contains exactly one ``@``; otherwise ``(name, None)`` is returned.
        """
        parts = name.split("@")
        if len(parts) == 2:
            return parts[0], parts[1]
        return name, None

    def save_pretrained(
        self,
        save_directory: str,
        config: Optional[dict] = None,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Saving weights in local directory.

        Parameters:
            save_directory (:obj:`str`):
                Specify directory in which you want to save weights. It is created if it does not exist.
            config (:obj:`dict`, `optional`):
                specify config (must be dict) in case you want to save it. Anything that is not a dict is ignored.
            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set it to `True` in case you want to push your weights to huggingface_hub
            kwargs (:obj:`Dict`, `optional`):
                kwargs will be passed to `push_to_hub` (e.g. ``model_id``, ``repo_url``, ``commit_message``,
                ``organization``, ``private``). Only used when ``push_to_hub`` is `True`.

        Returns:
            The return value of `push_to_hub` when ``push_to_hub`` is `True`, otherwise :obj:`None`.
        """

        os.makedirs(save_directory, exist_ok=True)

        # saving config
        if isinstance(config, dict):
            path = os.path.join(save_directory, CONFIG_NAME)
            # Written with the same encoding `from_pretrained` uses to read it back.
            with open(path, "w", encoding="utf-8") as f:
                json.dump(config, f)

        # saving model weights
        path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
        self._save_pretrained(path)

        if push_to_hub:
            return self.push_to_hub(save_directory, **kwargs)

    def _save_pretrained(self, path):
        """
        Overwrite this method in case you don't want to save complete model, rather some specific layers

        Parameters:
            path (:obj:`str`):
                Destination file for the serialized state dict.
        """
        torch = self._import_torch()
        # If the object exposes a `.module` attribute, the state dict of that inner object is saved instead.
        model_to_save = getattr(self, "module", self)
        torch.save(model_to_save.state_dict(), path)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[str],
        strict: bool = True,
        map_location: Optional[str] = "cpu",
        force_download: bool = False,
        resume_download: bool = False,
        proxies: Optional[Dict] = None,
        use_auth_token: Optional[str] = None,
        cache_dir: Optional[str] = None,
        local_files_only: bool = False,
        **model_kwargs,
    ):
        r"""
        Instantiate a pretrained pytorch model from a local directory or from huggingface-hub.
        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated). To
        train the model, you should first set it back in training mode with ``model.train()``.

        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:
                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - You can add `revision` by appending `@` at the end of model_id simply like this:
                      ``dbmdz/bert-base-german-cased@main``. Revision is the specific model version to use. It can be
                      a branch name, a tag name, or a commit id, since we use a git-based system for storing models
                      and other artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
                    - A path to a `directory` containing model weights saved using :func:`save_pretrained`, e.g.,
                      ``./my_model_directory/``. An existing directory always takes precedence over the Hub, and
                      nothing is downloaded in that case.
            strict (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Forwarded to ``load_state_dict``.
            map_location (:obj:`str`, `optional`, defaults to :obj:`"cpu"`):
                Device on which the weights are loaded. :obj:`None` is forwarded to ``torch.load`` unchanged.
            cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (:obj:`str`, `optional`):
                The token to use as HTTP bearer authorization for remote files; forwarded to ``cached_download``.
            model_kwargs (:obj:`Dict`, `optional`):
                model_kwargs will be passed to the model during initialization. When a config file is found, its
                content is added under the ``config`` key (overriding any ``config`` passed here).

        Returns:
            An instance of ``cls`` with the weights loaded, in evaluation mode.

        Raises:
            ValueError: if ``pretrained_model_name_or_path`` is :obj:`None`.
            ImportError: if PyTorch is not installed.

        .. note::
            Passing :obj:`use_auth_token` is required when you want to use a private model.
        """

        if pretrained_model_name_or_path is None:
            raise ValueError(
                "`pretrained_model_name_or_path` must be a model id or a path to a local directory, got None."
            )

        torch = cls._import_torch()

        model_id, revision = cls._split_revision(
            os.fspath(pretrained_model_name_or_path)
        )

        if map_location is not None:
            map_location = torch.device(map_location)

        if os.path.isdir(model_id):
            # Local checkpoint: works for nested / absolute paths too, and never touches the network.
            logger.info("LOADING weights from local directory")
            config_file = os.path.join(model_id, CONFIG_NAME)
            if not os.path.isfile(config_file):
                # The config is optional; `save_pretrained` only writes it when given a dict.
                config_file = None
            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
        else:
            download_kwargs = dict(
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )

            # The config is optional on the Hub as well: a failed request only disables it.
            try:
                config_url = hf_hub_url(
                    model_id, filename=CONFIG_NAME, revision=revision
                )
                config_file = cached_download(config_url, **download_kwargs)
            except requests.exceptions.RequestException:
                logger.warning("config.json NOT FOUND in HuggingFace Hub")
                config_file = None

            # The weights are mandatory: errors propagate to the caller.
            model_url = hf_hub_url(
                model_id, filename=PYTORCH_WEIGHTS_NAME, revision=revision
            )
            model_file = cached_download(model_url, **download_kwargs)

        if config_file is not None:
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
            model_kwargs["config"] = config

        model = cls(**model_kwargs)

        # NOTE(security): `torch.load` unpickles the file, which can execute arbitrary code.
        # Only load checkpoints from sources you trust.
        state_dict = torch.load(model_file, map_location=map_location)
        model.load_state_dict(state_dict, strict=strict)
        model.eval()

        return model

    @staticmethod
    def push_to_hub(
        save_directory: Optional[str],
        model_id: Optional[str] = None,
        repo_url: Optional[str] = None,
        commit_message: Optional[str] = "add model",
        organization: Optional[str] = None,
        private: Optional[bool] = None,
    ) -> str:
        """
        Push the content of ``save_directory`` to a model repo on huggingface_hub.

        Parameters:
            save_directory (:obj:`Union[str, os.PathLike]`):
                Directory having model weights & config.
            model_id (:obj:`str`, `optional`, defaults to the name of :obj:`save_directory`):
                Repo name in huggingface_hub. If not specified, repo name will be the last path component of
                `save_directory`.
            repo_url (:obj:`str`, `optional`):
                Specify this in case you want to push to existing repo in hub. When given, no repo is created.
            organization (:obj:`str`, `optional`):
                Organization in which you want to push your model.
            private (:obj:`bool`, `optional`):
                Whether the model repo should be private (requires a paid huggingface.co account)
            commit_message (:obj:`str`, `optional`, defaults to :obj:`add model`):
                Message to commit while pushing

        Returns:
            url to commit on remote repo.
        """
        if model_id is None:
            # Use only the directory's own name so that "out/run1/" or an absolute path still
            # yields a plain repo name ("run1") instead of a path with separators.
            model_id = os.path.basename(os.path.abspath(save_directory))

        token = HfFolder.get_token()
        if repo_url is None:
            repo_url = HfApi().create_repo(
                token,
                model_id,
                organization=organization,
                private=private,
                repo_type=None,
                exist_ok=True,
            )

        repo = Repository(save_directory, clone_from=repo_url, use_auth_token=token)

        return repo.push_to_hub(commit_message=commit_message)

src/huggingface_hub/repository.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,26 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
159159
encoding="utf-8",
160160
cwd=self.local_dir,
161161
)
162-
subprocess.run(
163-
["git", "remote", "add", "origin", repo_url],
162+
163+
output = subprocess.run(
164+
"git remote -v".split(),
164165
stderr=subprocess.PIPE,
165166
stdout=subprocess.PIPE,
166167
check=True,
167168
encoding="utf-8",
168169
cwd=self.local_dir,
169170
)
171+
172+
if "origin" not in output.stdout.split():
173+
subprocess.run(
174+
["git", "remote", "add", "origin", repo_url],
175+
stderr=subprocess.PIPE,
176+
stdout=subprocess.PIPE,
177+
check=True,
178+
encoding="utf-8",
179+
cwd=self.local_dir,
180+
)
181+
170182
subprocess.run(
171183
"git fetch".split(),
172184
stderr=subprocess.PIPE,
@@ -183,15 +195,27 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
183195
check=True,
184196
cwd=self.local_dir,
185197
)
186-
# TODO(check if we really want the --force flag)
187-
subprocess.run(
188-
"git checkout origin/main -ft".split(),
198+
199+
output = subprocess.run(
200+
"git branch".split(),
189201
stderr=subprocess.PIPE,
190202
stdout=subprocess.PIPE,
191-
encoding="utf-8",
192203
check=True,
204+
encoding="utf-8",
193205
cwd=self.local_dir,
194206
)
207+
208+
if "main" not in output.stdout.split():
209+
# TODO(check if we really want the --force flag)
210+
subprocess.run(
211+
"git checkout origin/main -ft".split(),
212+
stderr=subprocess.PIPE,
213+
stdout=subprocess.PIPE,
214+
encoding="utf-8",
215+
check=True,
216+
cwd=self.local_dir,
217+
)
218+
195219
except subprocess.CalledProcessError as exc:
196220
raise EnvironmentError(exc.stderr)
197221

tests/test_hubmixin.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import unittest
2+
3+
from huggingface_hub.file_download import is_torch_available
4+
from huggingface_hub.hub_mixin import ModelHubMixin
5+
6+
7+
if is_torch_available():
8+
import torch.nn as nn
9+
10+
11+
# Namespace prefix of the model id that test_from_pretrained downloads
# (presumably the account whose token pushes the repo — confirm before running elsewhere).
HUGGINGFACE_ID = "vasudevgupta"
12+
# Used both as a local save directory and as the repo name pushed to / loaded from in the tests.
DUMMY_REPO_NAME = "dummy"
13+
14+
15+
def require_torch(test_case):
    """
    Mark ``test_case`` as needing PyTorch.

    When PyTorch is available the object is handed back untouched; otherwise it is wrapped with
    ``unittest.skip`` so the test is reported as skipped.
    """
    if is_torch_available():
        return test_case
    skip_without_torch = unittest.skip("test requires PyTorch")
    return skip_without_torch(test_case)
26+
27+
28+
if is_torch_available():

    class DummyModel(nn.Module, ModelHubMixin):
        """
        Minimal torch model used to exercise ``ModelHubMixin``.

        It must derive from ``nn.Module``: the mixin relies on ``state_dict``, ``load_state_dict`` and
        ``eval``, none of which the mixin defines itself.
        """

        def __init__(self, **kwargs):
            super().__init__()
            # `from_pretrained` passes the loaded config.json content under the "config" key.
            self.config = kwargs.pop("config", None)
            self.l1 = nn.Linear(2, 2)

        def forward(self, x):
            return self.l1(x)


else:
    # `nn` is not defined without PyTorch, so the class statement cannot even be evaluated.
    # The tests that use it are skipped by `require_torch` in that case.
    DummyModel = None
37+
38+
39+
@require_torch
class DummyModelTest(unittest.TestCase):
    """
    Round-trip tests for ``ModelHubMixin``.

    Every test pushes to the Hub, so they need network access and a stored token.
    """

    def test_save_pretrained(self):
        """Save locally, then push several times: same directory, new config, and a different directory."""
        model = DummyModel()
        model.save_pretrained(DUMMY_REPO_NAME)
        model.save_pretrained(
            DUMMY_REPO_NAME, config={"num": 12, "act": "gelu"}, push_to_hub=True
        )
        # Pushing the same directory a second time must not fail.
        model.save_pretrained(
            DUMMY_REPO_NAME, config={"num": 24, "act": "relu"}, push_to_hub=True
        )
        # A different local directory pushed to the same repo through an explicit `model_id`.
        model.save_pretrained(
            "dummy-wts", config=None, push_to_hub=True, model_id=DUMMY_REPO_NAME
        )

    def test_from_pretrained(self):
        """The config pushed with the weights is passed to the model on reload."""
        expected_config = {"num": 7, "act": "gelu_fast"}
        model = DummyModel()
        model.save_pretrained(
            DUMMY_REPO_NAME, config={"num": 7, "act": "gelu_fast"}, push_to_hub=True
        )

        model = DummyModel.from_pretrained(f"{HUGGINGFACE_ID}/{DUMMY_REPO_NAME}@main")
        # assertEqual (rather than assertTrue on ==) shows both dicts when they differ.
        self.assertEqual(model.config, expected_config)

    def test_push_to_hub(self):
        """Saving without pushing, then pushing the directory explicitly."""
        model = DummyModel()
        model.save_pretrained("dummy-wts", push_to_hub=False)
        model.push_to_hub("dummy-wts", model_id=DUMMY_REPO_NAME)

0 commit comments

Comments
 (0)