1+ import json
12import logging
23import os
34from pathlib import Path
5+ from typing import Any , Dict , Optional , Union
46
5- from huggingface_hub import ModelHubMixin , hf_hub_download
7+ from huggingface_hub import ModelHubMixin
8+ from huggingface_hub .file_download import is_tf_available
9+ from huggingface_hub .snapshot_download import snapshot_download
10+
11+ from .constants import CONFIG_NAME
12+ from .hf_api import HfApi , HfFolder
13+ from .repository import Repository
614
715
816logger = logging .getLogger (__name__ )
917
18+ if is_tf_available ():
19+ import tensorflow as tf
1020
11- class KerasModelHubMixin (ModelHubMixin ):
1221
13- _CONFIG_NAME = "config.json"
14- _WEIGHTS_NAME = "tf_model.h5"
def save_pretrained_keras(
    model, save_directory: str, config: Optional[Dict[str, Any]] = None
):
    """Save a Keras model to ``save_directory`` in SavedModel format.

    Use this if you're using the Functional or Sequential APIs.

    Parameters:
        model:
            The Keras model you'd like to save. The model must be compiled and built.
        save_directory (:obj:`str`):
            Directory in which you want to save the Keras model. It is created
            (including parents) if it does not exist yet.
        config (:obj:`dict`, `optional`):
            Configuration object to be saved alongside the model weights. When
            truthy, it is dumped as JSON into the file named by ``CONFIG_NAME``
            inside ``save_directory``. A falsy value (``None``, ``{}``) writes
            no config file.

    Raises:
        ValueError: If ``model.built`` is false.
        RuntimeError: If a non-empty ``config`` is not a ``dict``.
        ImportError: If TensorFlow is not available.
    """
    if not model.built:
        raise ValueError("Model should be built before trying to save")

    # Validate the config before touching the filesystem, so that invalid input
    # does not leave a freshly created, empty directory behind.
    if config and not isinstance(config, dict):
        raise RuntimeError(
            f"Provided config to save_pretrained_keras should be a dict. Got: '{type(config)}'"
        )

    # `tf` is only bound at module level when TensorFlow could be imported;
    # fail with an explicit error instead of a NameError on `tf` below.
    if not is_tf_available():
        raise ImportError(
            "Called a Tensorflow-specific function but could not import it."
        )

    os.makedirs(save_directory, exist_ok=True)

    # saving config
    if config:
        path = os.path.join(save_directory, CONFIG_NAME)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(config, f)

    tf.keras.models.save_model(model, save_directory)
51+
52+
def from_pretrained_keras(*args, **kwargs):
    """Load a Keras model through ``KerasModelHubMixin.from_pretrained``.

    Functional counterpart of the mixin's loader: all positional and keyword
    arguments are forwarded unchanged and the loader's result is returned as is.
    """
    loader = KerasModelHubMixin.from_pretrained
    return loader(*args, **kwargs)
55+
56+
def push_to_hub_keras(
    model,
    repo_path_or_name: Optional[str] = None,
    repo_url: Optional[str] = None,
    commit_message: Optional[str] = "Add model",
    organization: Optional[str] = None,
    private: Optional[bool] = None,
    api_endpoint: Optional[str] = None,
    use_auth_token: Optional[Union[bool, str]] = True,
    git_user: Optional[str] = None,
    git_email: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
):
    """
    Upload a Keras model to the 🤗 Model Hub while synchronizing a local clone of the repo in
    :obj:`repo_path_or_name`.

    Parameters:
        model:
            The Keras model you'd like to push to the hub. The model must be compiled and built.
        repo_path_or_name (:obj:`str`, `optional`):
            Can either be a repository name for your model in the Hub or a path to a local folder (in
            which case the repository will have the name of that local folder). If not specified, will default to
            the last path component of :obj:`repo_url` and a local directory with that name will be used.
        repo_url (:obj:`str`, `optional`):
            Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
            repository will be created in your namespace (unless you specify an :obj:`organization`) with
            :obj:`repo_name`.
        commit_message (:obj:`str`, `optional`):
            Message to commit while pushing. Will default to :obj:`"Add model"`.
        organization (:obj:`str`, `optional`):
            Organization in which you want to push your model (you must be a member of this
            organization).
        private (:obj:`bool`, `optional`):
            Whether or not the repository created should be private (requires a paying subscription).
        api_endpoint (:obj:`str`, `optional`):
            The API endpoint to use when pushing the model to the hub.
        use_auth_token (:obj:`bool` or :obj:`str`, `optional`):
            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
            generated when running :obj:`huggingface-cli login` (stored in :obj:`~/.huggingface`). Will default to
            :obj:`True`.
        git_user (``str``, `optional`):
            will override the ``git config user.name`` for committing and pushing files to the hub.
        git_email (``str``, `optional`):
            will override the ``git config user.email`` for committing and pushing files to the hub.
        config (:obj:`dict`, `optional`):
            Configuration object to be saved alongside the model weights.

    Returns:
        The url of the commit of your model in the given repository (the value returned by
        ``Repository.git_push``).

    Raises:
        ValueError: If neither :obj:`repo_path_or_name` nor :obj:`repo_url` is given, or if no
            authentication token could be resolved from :obj:`use_auth_token`.
    """

    if repo_path_or_name is None and repo_url is None:
        raise ValueError("You need to specify a `repo_path_or_name` or a `repo_url`.")

    # Resolve the token: True -> locally stored token, str -> used verbatim,
    # anything else (False / None) -> no token.
    if isinstance(use_auth_token, bool) and use_auth_token:
        token = HfFolder.get_token()
    elif isinstance(use_auth_token, str):
        token = use_auth_token
    else:
        token = None

    if token is None:
        raise ValueError(
            "You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
            "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
            "token as the `use_auth_token` argument."
        )

    if repo_path_or_name is None:
        # Strip trailing slashes first: otherwise a URL such as ".../user/model/"
        # would yield an empty name for the local directory.
        repo_path_or_name = repo_url.rstrip("/").split("/")[-1]

    # If no URL is passed and there's no path to a directory containing files, create a repo
    if repo_url is None and not os.path.exists(repo_path_or_name):
        repo_name = Path(repo_path_or_name).name
        repo_url = HfApi(endpoint=api_endpoint).create_repo(
            token,
            repo_name,
            organization=organization,
            private=private,
            repo_type=None,
            exist_ok=True,
        )

    repo = Repository(
        repo_path_or_name,
        clone_from=repo_url,
        use_auth_token=use_auth_token,
        git_user=git_user,
        git_email=git_email,
    )
    # Bring the local clone up to date before writing the model into it.
    repo.git_pull(rebase=True)

    save_pretrained_keras(model, repo_path_or_name, config=config)

    # Commit and push!
    repo.git_add(auto_lfs_track=True)
    repo.git_commit(commit_message)
    return repo.git_push()
156+
15157
158+ class KerasModelHubMixin (ModelHubMixin ):
16159 def __init__ (self , * args , ** kwargs ):
17160 """
18161 Mix this class with your keras-model class for ease process of saving & loading from huggingface-hub
19162
20- NOTE - Dummy Inputs are required to save/load models using this mixin. When saving, you are required to either:
21-
22- 1. Assign an attribute to your class, self.dummy_inputs, that defines inputs to be passed to the model's call
23- function to build the model.
24- 2. Pass the dummy_inputs kwarg to save_pretrained. We will save this along with the model (as if it were an attribute).
25-
26163 Example::
27164
28165 >>> from huggingface_hub import KerasModelHubMixin
@@ -36,33 +173,22 @@ def __init__(self, *args, **kwargs):
36173 ... def call(self, ...)
37174 ... return ...
38175
176+ >>> # Init and compile the model as you normally would
39177 >>> model = MyModel()
40- >>> model.save_pretrained("mymodel", push_to_hub=False) # Saving model weights in the directory
41- >>> model.push_to_hub("mymodel", "model-1") # Pushing model-weights to hf-hub
178+ >>> model.compile(...)
179+ >>> # Build the graph by training it or passing dummy inputs
180+ >>> _ = model(model.dummy_inputs)
181+ >>> # You can save your model like this
182+ >>> model.save_pretrained("local_model_dir/", push_to_hub=False)
183+ >>> # Or, you can push to a new public model repo like this
184+ >>> model.push_to_hub("super-cool-model", git_user="your-hf-username", git_email="[email protected] ") 42185
43186 >>> # Downloading weights from hf-hub & model will be initialized from those weights
44187 >>> model = MyModel.from_pretrained("username/mymodel@main")
45188 """
46189
47- def _save_pretrained (self , save_directory , dummy_inputs = None , ** kwargs ):
48-
49- dummy_inputs = (
50- dummy_inputs
51- if dummy_inputs is not None
52- else getattr (self , "dummy_inputs" , None )
53- )
54-
55- if dummy_inputs is None :
56- raise RuntimeError (
57- "You must either provide dummy inputs or have them assigned as an attribute of this model"
58- )
59-
60- _ = self (dummy_inputs , training = False )
61-
62- save_directory = Path (save_directory )
63- model_file = save_directory / self ._WEIGHTS_NAME
64- self .save_weights (model_file )
65- logger .info (f"Model weights saved in { model_file } " )
def _save_pretrained(self, save_directory):
    """Write this model into ``save_directory`` via ``save_pretrained_keras``.

    No config is forwarded, so the helper's default (``config=None``) applies
    and only the model itself is saved.
    """
    save_pretrained_keras(model=self, save_directory=save_directory)
66192
67193 @classmethod
68194 def _from_pretrained (
@@ -75,34 +201,27 @@ def _from_pretrained(
75201 resume_download ,
76202 local_files_only ,
77203 use_auth_token ,
78- by_name = False ,
79204 ** model_kwargs ,
80205 ):
81- if os .path .isdir (model_id ):
82- print ("Loading weights from local directory" )
83- model_file = os .path .join (model_id , cls ._WEIGHTS_NAME )
84- else :
85- model_file = hf_hub_download (
86- repo_id = model_id ,
87- filename = cls ._WEIGHTS_NAME ,
88- revision = revision ,
89- cache_dir = cache_dir ,
90- force_download = force_download ,
91- proxies = proxies ,
92- resume_download = resume_download ,
93- use_auth_token = use_auth_token ,
94- local_files_only = local_files_only ,
95- )
206+ """Here we just call from_pretrained_keras function so both the mixin and functional APIs stay in sync.
96207
97- model = cls (** model_kwargs )
208+ TODO - Some args above aren't used since we are calling snapshot_download instead of hf_hub_download.
209+ """
98210
99- if hasattr ( model , "dummy_inputs" ) and model . dummy_inputs is not None :
100- raise ValueError ( "Model must have a dummy_inputs attribute" )
211+ # TODO - Figure out what to do about these config values. Config is not going to be needed to load model
212+ cfg = model_kwargs . pop ( "config" , None )
101213
102- _ = model (model .dummy_inputs , training = False )
214+ # Root is either a local filepath matching model_id or a cached snapshot
215+ if not os .path .isdir (model_id ):
216+ storage_folder = snapshot_download (
217+ repo_id = model_id , revision = revision , cache_dir = cache_dir
218+ )
219+ else :
220+ storage_folder = model_id
103221
104- model . load_weights ( model_file , by_name = by_name )
222+ model = tf . keras . models . load_model ( storage_folder , ** model_kwargs )
105223
106- _ = model (model .dummy_inputs , training = False )
224+ # For now, we add a new attribute, config, to store the config loaded from the hub/a local dir.
225+ model .config = cfg
107226
108227 return model
0 commit comments