Skip to content

Commit 07b0ad0

Browse files
nateraw and LysandreJik authored
Update Keras Mixin (#284)
* 🎨 add Keras mixin to root level __init__.py
* 🎨 remove kwargs from _save_pretrained
* 🚧 wip open heart surgery on keras mixin/utils
* 🚧 wip
* ✅ update tests
* 💄 style
* ✅ work on tests
* 🐛 require tf in keras functional tests
* 🐛 add placeholders for dummy models
* ✅ Update test params to use cloned model
* ✅ fix tests
* ✅ move tf clone model inside test functions
* 💄 style
* ✅ test save does not work if keras model not built
* 🎨 add explicit endpoint
* 🔥 remove class attrs I decided not to use
* 🚧 wip
* ✅ clean up keras integration tests
* 🎨 make sure functional API includes config
* 👷 add tensorflow build to CI
* Apply suggestions from code review
  Co-authored-by: Lysandre Debut <[email protected]>
* 🚧 wip
* 🚧 wip
* 🚚 rename model.hf_config to model.config
* 🔥 rm some code
* ✏️ Updates from code review
  Co-authored-by: Lysandre Debut <[email protected]>
1 parent b26da2a commit 07b0ad0

File tree

7 files changed

+480
-210
lines changed

7 files changed

+480
-210
lines changed

.github/workflows/python-tests.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,23 @@ jobs:
5757
5858
- run: pytest -sv ./tests/
5959

60+
build_tensorflow:
61+
runs-on: ubuntu-latest
62+
63+
steps:
64+
- uses: actions/checkout@v2
65+
66+
- name: Set up Python ${{ matrix.python-version }}
67+
uses: actions/setup-python@v2
68+
with:
69+
python-version: ${{ matrix.python-version }}
70+
71+
- name: Install dependencies
72+
run: |
73+
pip install --upgrade pip
74+
pip install .[testing,tensorflow]
75+
76+
- run: pytest -sv ./tests/
6077

6178
tests_lfs:
6279
runs-on: ubuntu-latest

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ def get_version() -> str:
2727
"torch",
2828
]
2929

30+
extras["tensorflow"] = [
31+
"tensorflow",
32+
]
33+
3034
extras["testing"] = [
3135
"pytest",
3236
"datasets",

src/huggingface_hub/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@
3434
from .hf_api import HfApi, HfFolder, repo_type_and_id_from_hf_id
3535
from .hub_mixin import ModelHubMixin, PyTorchModelHubMixin
3636
from .inference_api import InferenceApi
37+
from .keras_mixin import (
38+
KerasModelHubMixin,
39+
from_pretrained_keras,
40+
push_to_hub_keras,
41+
save_pretrained_keras,
42+
)
3743
from .repository import Repository
3844
from .snapshot_download import snapshot_download
3945
from .utils import logging

src/huggingface_hub/hub_mixin.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ def save_pretrained(
5858
json.dump(config, f)
5959

6060
# saving model weights/files
61-
self._save_pretrained(save_directory, **kwargs)
61+
self._save_pretrained(save_directory)
6262

6363
if push_to_hub:
6464
return self.push_to_hub(save_directory, **kwargs)
6565

66-
def _save_pretrained(self, save_directory, **kwargs):
66+
def _save_pretrained(self, save_directory):
6767
"""
6868
Overwrite this method in subclass to define how to save your model.
6969
"""
@@ -144,10 +144,10 @@ def from_pretrained(
144144
local_files_only=local_files_only,
145145
)
146146
except requests.exceptions.RequestException:
147-
logger.warning("config.json NOT FOUND in HuggingFace Hub")
147+
logger.warning(f"{CONFIG_NAME} not found in HuggingFace Hub")
148148
config_file = None
149149

150-
if config_file is not None:
150+
if config_file is not None and config_file.endswith(".json"):
151151
with open(config_file, "r", encoding="utf-8") as f:
152152
config = json.load(f)
153153
model_kwargs.update({"config": config})

src/huggingface_hub/keras_mixin.py

Lines changed: 172 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,165 @@
1+
import json
12
import logging
23
import os
34
from pathlib import Path
5+
from typing import Any, Dict, Optional, Union
46

5-
from huggingface_hub import ModelHubMixin, hf_hub_download
7+
from huggingface_hub import ModelHubMixin
8+
from huggingface_hub.file_download import is_tf_available
9+
from huggingface_hub.snapshot_download import snapshot_download
10+
11+
from .constants import CONFIG_NAME
12+
from .hf_api import HfApi, HfFolder
13+
from .repository import Repository
614

715

816
logger = logging.getLogger(__name__)
917

18+
if is_tf_available():
19+
import tensorflow as tf
1020

11-
class KerasModelHubMixin(ModelHubMixin):
1221

13-
_CONFIG_NAME = "config.json"
14-
_WEIGHTS_NAME = "tf_model.h5"
22+
def save_pretrained_keras(
23+
model, save_directory: str, config: Optional[Dict[str, Any]] = None
24+
):
25+
"""Saves a Keras model to save_directory in SavedModel format. Use this if you're using the Functional or Sequential APIs.
26+
27+
model:
28+
The Keras model you'd like to save. The model must be compiled and built.
29+
save_directory (:obj:`str`):
30+
Specify directory in which you want to save the Keras model.
31+
config (:obj:`dict`, `optional`):
32+
Configuration object to be saved alongside the model weights.
33+
"""
34+
35+
if not model.built:
36+
raise ValueError("Model should be built before trying to save")
37+
38+
os.makedirs(save_directory, exist_ok=True)
39+
40+
# saving config
41+
if config:
42+
if not isinstance(config, dict):
43+
raise RuntimeError(
44+
f"Provided config to save_pretrained_keras should be a dict. Got: '{type(config)}'"
45+
)
46+
path = os.path.join(save_directory, CONFIG_NAME)
47+
with open(path, "w") as f:
48+
json.dump(config, f)
49+
50+
tf.keras.models.save_model(model, save_directory)
51+
52+
53+
def from_pretrained_keras(*args, **kwargs):
54+
return KerasModelHubMixin.from_pretrained(*args, **kwargs)
55+
56+
57+
def push_to_hub_keras(
58+
model,
59+
repo_path_or_name: Optional[str] = None,
60+
repo_url: Optional[str] = None,
61+
commit_message: Optional[str] = "Add model",
62+
organization: Optional[str] = None,
63+
private: Optional[bool] = None,
64+
api_endpoint: Optional[str] = None,
65+
use_auth_token: Optional[Union[bool, str]] = True,
66+
git_user: Optional[str] = None,
67+
git_email: Optional[str] = None,
68+
config: Optional[dict] = None,
69+
):
70+
"""
71+
Upload model checkpoint or tokenizer files to the 🤗 Model Hub while synchronizing a local clone of the repo in
72+
:obj:`repo_path_or_name`.
73+
74+
Parameters:
75+
model:
76+
The Keras model you'd like to push to the hub. The model must be compiled and built.
77+
repo_path_or_name (:obj:`str`, `optional`):
78+
Can either be a repository name for your model or tokenizer in the Hub or a path to a local folder (in
79+
which case the repository will have the name of that local folder). If not specified, will default to
80+
the name given by :obj:`repo_url` and a local directory with that name will be created.
81+
repo_url (:obj:`str`, `optional`):
82+
Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
83+
repository will be created in your namespace (unless you specify an :obj:`organization`) with
84+
:obj:`repo_name`.
85+
commit_message (:obj:`str`, `optional`):
86+
Message to commit while pushing. Will default to :obj:`"Add model"`.
87+
organization (:obj:`str`, `optional`):
88+
Organization in which you want to push your model or tokenizer (you must be a member of this
89+
organization).
90+
private (:obj:`bool`, `optional`):
91+
Whether or not the repository created should be private (requires a paying subscription).
92+
api_endpoint (:obj:`str`, `optional`):
93+
The API endpoint to use when pushing the model to the hub.
94+
use_auth_token (:obj:`bool` or :obj:`str`, `optional`):
95+
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
96+
generated when running :obj:`huggingface-cli login` (stored in :obj:`~/.huggingface`). Will default to
97+
:obj:`True`.
98+
git_user (``str``, `optional`):
99+
will override the ``git config user.name`` for committing and pushing files to the hub.
100+
git_email (``str``, `optional`):
101+
will override the ``git config user.email`` for committing and pushing files to the hub.
102+
config (:obj:`dict`, `optional`):
103+
Configuration object to be saved alongside the model weights.
104+
105+
Returns:
106+
The url of the commit of your model in the given repository.
107+
"""
108+
109+
if repo_path_or_name is None and repo_url is None:
110+
raise ValueError("You need to specify a `repo_path_or_name` or a `repo_url`.")
111+
112+
if isinstance(use_auth_token, bool) and use_auth_token:
113+
token = HfFolder.get_token()
114+
elif isinstance(use_auth_token, str):
115+
token = use_auth_token
116+
else:
117+
token = None
118+
119+
if token is None:
120+
raise ValueError(
121+
"You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
122+
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
123+
"token as the `use_auth_token` argument."
124+
)
125+
126+
if repo_path_or_name is None:
127+
repo_path_or_name = repo_url.split("/")[-1]
128+
129+
# If no URL is passed and there's no path to a directory containing files, create a repo
130+
if repo_url is None and not os.path.exists(repo_path_or_name):
131+
repo_name = Path(repo_path_or_name).name
132+
repo_url = HfApi(endpoint=api_endpoint).create_repo(
133+
token,
134+
repo_name,
135+
organization=organization,
136+
private=private,
137+
repo_type=None,
138+
exist_ok=True,
139+
)
140+
141+
repo = Repository(
142+
repo_path_or_name,
143+
clone_from=repo_url,
144+
use_auth_token=use_auth_token,
145+
git_user=git_user,
146+
git_email=git_email,
147+
)
148+
repo.git_pull(rebase=True)
149+
150+
save_pretrained_keras(model, repo_path_or_name, config=config)
151+
152+
# Commit and push!
153+
repo.git_add(auto_lfs_track=True)
154+
repo.git_commit(commit_message)
155+
return repo.git_push()
156+
15157

158+
class KerasModelHubMixin(ModelHubMixin):
16159
def __init__(self, *args, **kwargs):
17160
"""
18161
Mix this class with your Keras model class to easily save it to, and load it from, the Hugging Face Hub.
19162
20-
NOTE - Dummy Inputs are required to save/load models using this mixin. When saving, you are required to either:
21-
22-
1. Assign an attribute to your class, self.dummy_inputs, that defines inputs to be passed to the model's call
23-
function to build the model.
24-
2. Pass the dummy_inputs kwarg to save_pretrained. We will save this along with the model (as if it were an attribute).
25-
26163
Example::
27164
28165
>>> from huggingface_hub import KerasModelHubMixin
@@ -36,33 +173,22 @@ def __init__(self, *args, **kwargs):
36173
... def call(self, ...)
37174
... return ...
38175
176+
>>> # Init and compile the model as you normally would
39177
>>> model = MyModel()
40-
>>> model.save_pretrained("mymodel", push_to_hub=False) # Saving model weights in the directory
41-
>>> model.push_to_hub("mymodel", "model-1") # Pushing model-weights to hf-hub
178+
>>> model.compile(...)
179+
>>> # Build the graph by training it or passing dummy inputs
180+
>>> _ = model(model.dummy_inputs)
181+
>>> # You can save your model like this
182+
>>> model.save_pretrained("local_model_dir/", push_to_hub=False)
183+
>>> # Or, you can push to a new public model repo like this
184+
>>> model.push_to_hub("super-cool-model", git_user="your-hf-username", git_email="[email protected]")
42185
43186
>>> # Downloading weights from hf-hub & model will be initialized from those weights
44187
>>> model = MyModel.from_pretrained("username/mymodel@main")
45188
"""
46189

47-
def _save_pretrained(self, save_directory, dummy_inputs=None, **kwargs):
48-
49-
dummy_inputs = (
50-
dummy_inputs
51-
if dummy_inputs is not None
52-
else getattr(self, "dummy_inputs", None)
53-
)
54-
55-
if dummy_inputs is None:
56-
raise RuntimeError(
57-
"You must either provide dummy inputs or have them assigned as an attribute of this model"
58-
)
59-
60-
_ = self(dummy_inputs, training=False)
61-
62-
save_directory = Path(save_directory)
63-
model_file = save_directory / self._WEIGHTS_NAME
64-
self.save_weights(model_file)
65-
logger.info(f"Model weights saved in {model_file}")
190+
def _save_pretrained(self, save_directory):
191+
save_pretrained_keras(self, save_directory)
66192

67193
@classmethod
68194
def _from_pretrained(
@@ -75,34 +201,27 @@ def _from_pretrained(
75201
resume_download,
76202
local_files_only,
77203
use_auth_token,
78-
by_name=False,
79204
**model_kwargs,
80205
):
81-
if os.path.isdir(model_id):
82-
print("Loading weights from local directory")
83-
model_file = os.path.join(model_id, cls._WEIGHTS_NAME)
84-
else:
85-
model_file = hf_hub_download(
86-
repo_id=model_id,
87-
filename=cls._WEIGHTS_NAME,
88-
revision=revision,
89-
cache_dir=cache_dir,
90-
force_download=force_download,
91-
proxies=proxies,
92-
resume_download=resume_download,
93-
use_auth_token=use_auth_token,
94-
local_files_only=local_files_only,
95-
)
206+
"""Here we just call from_pretrained_keras function so both the mixin and functional APIs stay in sync.
96207
97-
model = cls(**model_kwargs)
208+
TODO - Some args above aren't used since we are calling snapshot_download instead of hf_hub_download.
209+
"""
98210

99-
if hasattr(model, "dummy_inputs") and model.dummy_inputs is not None:
100-
raise ValueError("Model must have a dummy_inputs attribute")
211+
# TODO - Figure out what to do about these config values. Config is not going to be needed to load model
212+
cfg = model_kwargs.pop("config", None)
101213

102-
_ = model(model.dummy_inputs, training=False)
214+
# Root is either a local filepath matching model_id or a cached snapshot
215+
if not os.path.isdir(model_id):
216+
storage_folder = snapshot_download(
217+
repo_id=model_id, revision=revision, cache_dir=cache_dir
218+
)
219+
else:
220+
storage_folder = model_id
103221

104-
model.load_weights(model_file, by_name=by_name)
222+
model = tf.keras.models.load_model(storage_folder, **model_kwargs)
105223

106-
_ = model(model.dummy_inputs, training=False)
224+
# For now, we add a new attribute, config, to store the config loaded from the hub/a local dir.
225+
model.config = cfg
107226

108227
return model

0 commit comments

Comments
 (0)