22import json
33import os
44import warnings
5+ from functools import wraps
56from pathlib import Path
67from shutil import copytree
78from typing import Any , Dict , List , Optional , Union
1819from .constants import CONFIG_NAME
1920from .hf_api import HfApi
2021from .utils import SoftTemporaryDirectory , logging , validate_hf_hub_args
22+ from .utils ._typing import CallableT
2123
2224
2325logger = logging .get_logger (__name__ )
2426
27+ keras = None
2528if is_tf_available ():
26- import tensorflow as tf # type: ignore
29+ # Depending on which version of TensorFlow is installed, we need to import
30+ # keras from the correct location.
31+ # See https://github.com/tensorflow/tensorflow/releases/tag/v2.16.1.
32+ # Note: saving a keras model only works with Keras<3.0.
33+ try :
34+ import tf_keras as keras # type: ignore
35+ except ImportError :
36+ import tensorflow as tf # type: ignore
37+
38+ keras = tf .keras
39+
40+
41+ def _requires_keras_2_model (fn : CallableT ) -> CallableT :
42+ # Wrapper to raise if user tries to save a Keras 3.x model
43+ @wraps (fn )
44+ def _inner (model , * args , ** kwargs ):
45+ if not hasattr (model , "history" ): # hacky way to check if model is Keras 2.x
46+ raise NotImplementedError (
47+ f"Cannot use '{ fn .__name__ } ': Keras 3.x is not supported."
48+ " Please save models manually and upload them using `upload_folder` or `huggingface-cli upload`."
49+ )
50+ return fn (model , * args , ** kwargs )
51+
52+ return _inner # type: ignore [return-value]
2753
2854
2955def _flatten_dict (dictionary , parent_key = "" ):
@@ -62,15 +88,15 @@ def _create_hyperparameter_table(model):
6288 optimizer_params = model .optimizer .get_config ()
6389 # flatten the configuration
6490 optimizer_params = _flatten_dict (optimizer_params )
65- optimizer_params ["training_precision" ] = tf . keras .mixed_precision .global_policy ().name
91+ optimizer_params ["training_precision" ] = keras .mixed_precision .global_policy ().name
6692 table = "| Hyperparameters | Value |\n | :-- | :-- |\n "
6793 for key , value in optimizer_params .items ():
6894 table += f"| { key } | { value } |\n "
6995 return table
7096
7197
7298def _plot_network (model , save_directory ):
73- tf . keras .utils .plot_model (
99+ keras .utils .plot_model (
74100 model ,
75101 to_file = f"{ save_directory } /model.png" ,
76102 show_shapes = False ,
@@ -127,6 +153,7 @@ def _create_model_card(
127153 readme_path .write_text (model_card )
128154
129155
156+ @_requires_keras_2_model
130157def save_pretrained_keras (
131158 model ,
132159 save_directory : Union [str , Path ],
@@ -161,9 +188,7 @@ def save_pretrained_keras(
161188 model_save_kwargs will be passed to
162189 [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model).
163190 """
164- if is_tf_available ():
165- import tensorflow as tf
166- else :
191+ if keras is None :
167192 raise ImportError ("Called a Tensorflow-specific function but could not import it." )
168193
169194 if not model .built :
@@ -209,7 +234,7 @@ def save_pretrained_keras(
209234 json .dump (model .history .history , f , indent = 2 , sort_keys = True )
210235
211236 _create_model_card (model , save_directory , plot_model , metadata )
212- tf . keras .models .save_model (model , save_directory , include_optimizer = include_optimizer , ** model_save_kwargs )
237+ keras .models .save_model (model , save_directory , include_optimizer = include_optimizer , ** model_save_kwargs )
213238
214239
215240def from_pretrained_keras (* args , ** kwargs ) -> "KerasModelHubMixin" :
@@ -272,6 +297,7 @@ def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin":
272297
273298
274299@validate_hf_hub_args
300+ @_requires_keras_2_model
275301def push_to_hub_keras (
276302 model ,
277303 repo_id : str ,
@@ -452,9 +478,7 @@ def _from_pretrained(
452478 TODO - Some args above aren't used since we are calling
453479 snapshot_download instead of hf_hub_download.
454480 """
455- if is_tf_available ():
456- import tensorflow as tf
457- else :
481+ if keras is None :
458482 raise ImportError ("Called a TensorFlow-specific function but could not import it." )
459483
460484 # Root is either a local filepath matching model_id or a cached snapshot
@@ -470,7 +494,7 @@ def _from_pretrained(
470494 storage_folder = model_id
471495
472496 # TODO: change this in a future PR. We are not returning a KerasModelHubMixin instance here...
473- model = tf . keras .models .load_model (storage_folder )
497+ model = keras .models .load_model (storage_folder )
474498
475499 # For now, we add a new attribute, config, to store the config loaded from the hub/a local dir.
476500 model .config = config
0 commit comments