Skip to content

Commit af792e8

Browse files
authored
Explicitly fail on Keras3 (#2107)
* Explicitly fail on Keras3
* code quality
* try to load from tf_keras
* allow saving if tf_keras is installed
* lint
1 parent 6d69291 commit af792e8

File tree

4 files changed

+57
-13
lines changed

4 files changed

+57
-13
lines changed

.github/workflows/python-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ jobs:
7575
tensorflow)
7676
sudo apt update
7777
sudo apt install -y graphviz
78-
uv pip install "huggingface_hub[tensorflow] @ ."
78+
uv pip install "huggingface_hub[tensorflow-testing] @ ."
7979
;;
8080
8181
esac

setup.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,16 @@ def get_version() -> str:
4545
"fastcore>=1.3.27",
4646
]
4747

48-
extras["tensorflow"] = ["tensorflow", "pydot", "graphviz"]
48+
extras["tensorflow"] = [
49+
"tensorflow",
50+
"pydot",
51+
"graphviz",
52+
]
53+
54+
extras["tensorflow-testing"] = [
55+
"tensorflow",
56+
"keras<3.0",
57+
]
4958

5059

5160
extras["testing"] = (

src/huggingface_hub/keras_mixin.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
import warnings
5+
from functools import wraps
56
from pathlib import Path
67
from shutil import copytree
78
from typing import Any, Dict, List, Optional, Union
@@ -18,12 +19,37 @@
1819
from .constants import CONFIG_NAME
1920
from .hf_api import HfApi
2021
from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args
22+
from .utils._typing import CallableT
2123

2224

2325
logger = logging.get_logger(__name__)
2426

27+
keras = None
2528
if is_tf_available():
26-
import tensorflow as tf # type: ignore
29+
# Depending on which version of TensorFlow is installed, we need to import
30+
# keras from the correct location.
31+
# See https://github.com/tensorflow/tensorflow/releases/tag/v2.16.1.
32+
# Note: saving a keras model only works with Keras<3.0.
33+
try:
34+
import tf_keras as keras # type: ignore
35+
except ImportError:
36+
import tensorflow as tf # type: ignore
37+
38+
keras = tf.keras
39+
40+
41+
def _requires_keras_2_model(fn: CallableT) -> CallableT:
42+
# Wrapper to raise if user tries to save a Keras 3.x model
43+
@wraps(fn)
44+
def _inner(model, *args, **kwargs):
45+
if not hasattr(model, "history"): # hacky way to check if model is Keras 2.x
46+
raise NotImplementedError(
47+
f"Cannot use '{fn.__name__}': Keras 3.x is not supported."
48+
" Please save models manually and upload them using `upload_folder` or `huggingface-cli upload`."
49+
)
50+
return fn(model, *args, **kwargs)
51+
52+
return _inner # type: ignore [return-value]
2753

2854

2955
def _flatten_dict(dictionary, parent_key=""):
@@ -62,15 +88,15 @@ def _create_hyperparameter_table(model):
6288
optimizer_params = model.optimizer.get_config()
6389
# flatten the configuration
6490
optimizer_params = _flatten_dict(optimizer_params)
65-
optimizer_params["training_precision"] = tf.keras.mixed_precision.global_policy().name
91+
optimizer_params["training_precision"] = keras.mixed_precision.global_policy().name
6692
table = "| Hyperparameters | Value |\n| :-- | :-- |\n"
6793
for key, value in optimizer_params.items():
6894
table += f"| {key} | {value} |\n"
6995
return table
7096

7197

7298
def _plot_network(model, save_directory):
73-
tf.keras.utils.plot_model(
99+
keras.utils.plot_model(
74100
model,
75101
to_file=f"{save_directory}/model.png",
76102
show_shapes=False,
@@ -127,6 +153,7 @@ def _create_model_card(
127153
readme_path.write_text(model_card)
128154

129155

156+
@_requires_keras_2_model
130157
def save_pretrained_keras(
131158
model,
132159
save_directory: Union[str, Path],
@@ -161,9 +188,7 @@ def save_pretrained_keras(
161188
model_save_kwargs will be passed to
162189
[`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model).
163190
"""
164-
if is_tf_available():
165-
import tensorflow as tf
166-
else:
191+
if keras is None:
167192
raise ImportError("Called a Tensorflow-specific function but could not import it.")
168193

169194
if not model.built:
@@ -209,7 +234,7 @@ def save_pretrained_keras(
209234
json.dump(model.history.history, f, indent=2, sort_keys=True)
210235

211236
_create_model_card(model, save_directory, plot_model, metadata)
212-
tf.keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs)
237+
keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs)
213238

214239

215240
def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin":
@@ -272,6 +297,7 @@ def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin":
272297

273298

274299
@validate_hf_hub_args
300+
@_requires_keras_2_model
275301
def push_to_hub_keras(
276302
model,
277303
repo_id: str,
@@ -452,9 +478,7 @@ def _from_pretrained(
452478
TODO - Some args above aren't used since we are calling
453479
snapshot_download instead of hf_hub_download.
454480
"""
455-
if is_tf_available():
456-
import tensorflow as tf
457-
else:
481+
if keras is None:
458482
raise ImportError("Called a TensorFlow-specific function but could not import it.")
459483

460484
# Root is either a local filepath matching model_id or a cached snapshot
@@ -470,7 +494,7 @@ def _from_pretrained(
470494
storage_folder = model_id
471495

472496
# TODO: change this in a future PR. We are not returning a KerasModelHubMixin instance here...
473-
model = tf.keras.models.load_model(storage_folder)
497+
model = keras.models.load_model(storage_folder)
474498

475499
# For now, we add a new attribute, config, to store the config loaded from the hub/a local dir.
476500
model.config = config

src/huggingface_hub/utils/_runtime.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"graphviz": {"graphviz"},
3636
"hf_transfer": {"hf_transfer"},
3737
"jinja": {"Jinja2"},
38+
"keras": {"keras"},
3839
"minijinja": {"minijinja"},
3940
"numpy": {"numpy"},
4041
"pillow": {"Pillow"},
@@ -140,6 +141,15 @@ def get_hf_transfer_version() -> str:
140141
return _get_version("hf_transfer")
141142

142143

144+
# keras
145+
def is_keras_available() -> bool:
146+
return is_package_available("keras")
147+
148+
149+
def get_keras_version() -> str:
150+
return _get_version("keras")
151+
152+
143153
# Minijinja
144154
def is_minijinja_available() -> bool:
145155
return is_package_available("minijinja")
@@ -332,6 +342,7 @@ def dump_environment_info() -> Dict[str, Any]:
332342
info["Torch"] = get_torch_version()
333343
info["Jinja2"] = get_jinja_version()
334344
info["Graphviz"] = get_graphviz_version()
345+
info["keras"] = get_keras_version()
335346
info["Pydot"] = get_pydot_version()
336347
info["Pillow"] = get_pillow_version()
337348
info["hf_transfer"] = get_hf_transfer_version()

0 commit comments

Comments
 (0)