Skip to content

Commit 697b20f

Browse files
feat: save/load TF Keras model (#89)
* feat: save/load TF Keras model * tensorflow >=2.0 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3c663b8 commit 697b20f

File tree

7 files changed

+140
-16
lines changed

7 files changed

+140
-16
lines changed

README.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,43 @@ trainer.fit(BoringModel(), ckpt_path=checkpoint_path)
123123

124124
</details>
125125

126+
<details>
127+
<summary>TensorFlow / Keras</summary>
128+
129+
Save model:
130+
131+
```python
132+
from tensorflow import keras
133+
134+
from litmodels import upload_model
135+
136+
# Define the model
137+
model = keras.Sequential(
138+
[
139+
keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
140+
keras.layers.Dense(10, name="dense_2"),
141+
]
142+
)
143+
144+
# Compile the model
145+
model.compile(optimizer="adam", loss="categorical_crossentropy")
146+
147+
# Save the model
148+
upload_model("lightning-ai/jirka/sample-tf-keras-model", model=model)
149+
```
150+
151+
Load model:
152+
153+
```python
154+
from litmodels import load_model
155+
156+
model_ = load_model(
157+
"lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model"
158+
)
159+
```
160+
161+
</details>
162+
126163
<details>
127164
<summary>SKLearn</summary>
128165

_requirements/test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ pytest-mock
66
pytorch-lightning >=2.0
77
scikit-learn >=1.0
88
huggingface-hub >=0.29.0
9+
tensorflow >=2.0

examples/demo-tensorflow-keras.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from tensorflow import keras
2+
3+
from litmodels import load_model, upload_model
4+
5+
if __name__ == "__main__":
6+
# Define the model
7+
model = keras.Sequential([
8+
keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
9+
keras.layers.Dense(10, name="dense_2"),
10+
])
11+
12+
# Compile the model
13+
model.compile(optimizer="adam", loss="categorical_crossentropy")
14+
15+
# Save the model
16+
upload_model("lightning-ai/jirka/sample-tf-keras-model", model=model)
17+
18+
# Load the model
19+
model_ = load_model("lightning-ai/jirka/sample-tf-keras-model", download_dir="./my-model")

src/litmodels/io/gateway.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33
from pathlib import Path
44
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
55

6-
from lightning_utilities import module_available
7-
86
from litmodels.io.cloud import download_model_files, upload_model_files
9-
from litmodels.io.utils import dump_pickle, load_pickle
7+
from litmodels.io.utils import _KERAS_AVAILABLE, _PYTORCH_AVAILABLE, dump_pickle, load_pickle
108

11-
if module_available("torch"):
9+
if _PYTORCH_AVAILABLE:
1210
import torch
13-
else:
14-
torch = None
11+
12+
if _KERAS_AVAILABLE:
13+
from tensorflow import keras
1514

1615
if TYPE_CHECKING:
1716
from lightning_sdk.models import UploadedModelInfo
@@ -48,12 +47,15 @@ def upload_model(
4847
# if LightningModule and isinstance(model, LightningModule):
4948
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
5049
# model.save_checkpoint(path)
51-
elif torch and isinstance(model, torch.jit.ScriptModule):
50+
elif _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule):
5251
path = os.path.join(staging_dir, f"{model.__class__.__name__}.ts")
5352
model.save(path)
54-
elif torch and isinstance(model, torch.nn.Module):
53+
elif _PYTORCH_AVAILABLE and isinstance(model, torch.nn.Module):
5554
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
5655
torch.save(model.state_dict(), path)
56+
elif _KERAS_AVAILABLE and isinstance(model, keras.models.Model):
57+
path = os.path.join(staging_dir, f"{model.__class__.__name__}.keras")
58+
model.save(path)
5759
else:
5860
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl")
5961
dump_pickle(model=model, path=path)
@@ -110,8 +112,10 @@ def load_model(name: str, download_dir: str = ".") -> Any:
110112
if len(download_paths) > 1:
111113
raise NotImplementedError("Downloaded model with multiple files is not supported yet.")
112114
model_path = Path(download_dir) / download_paths[0]
113-
if model_path.suffix.lower() == ".pkl":
114-
return load_pickle(path=model_path)
115115
if model_path.suffix.lower() == ".ts":
116116
return torch.jit.load(model_path)
117+
if model_path.suffix.lower() == ".keras":
118+
return keras.models.load_model(model_path)
119+
if model_path.suffix.lower() == ".pkl":
120+
return load_pickle(path=model_path)
117121
raise NotImplementedError(f"Loading model from {model_path.suffix} is not supported yet.")

src/litmodels/io/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
from typing import Any, Union
44

55
from lightning_utilities import module_available
6+
from lightning_utilities.core.imports import RequirementCache
67

7-
if module_available("joblib"):
8+
_JOBLIB_AVAILABLE = module_available("joblib")
9+
_PYTORCH_AVAILABLE = module_available("torch")
10+
_TENSORFLOW_AVAILABLE = module_available("tensorflow")
11+
_KERAS_AVAILABLE = RequirementCache("tensorflow >=2.0.0")
12+
13+
if _JOBLIB_AVAILABLE:
814
import joblib
9-
else:
10-
joblib = None
1115

1216

1317
def dump_pickle(model: Any, path: Union[str, Path]) -> None:
@@ -17,7 +21,7 @@ def dump_pickle(model: Any, path: Union[str, Path]) -> None:
1721
model: The model to be pickled.
1822
path: The path where the model will be saved.
1923
"""
20-
if joblib is not None:
24+
if _JOBLIB_AVAILABLE:
2125
joblib.dump(model, filename=path, compress=7)
2226
else:
2327
with open(path, "wb") as fp:
@@ -33,7 +37,7 @@ def load_pickle(path: Union[str, Path]) -> Any:
3337
Returns:
3438
The unpickled model.
3539
"""
36-
if joblib is not None:
40+
if _JOBLIB_AVAILABLE:
3741
return joblib.load(path)
3842
with open(path, "rb") as fp:
3943
return pickle.load(fp)

tests/integrations/test_cloud.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
1111
from lightning_sdk.utils.resolve import _resolve_teamspace
1212

13-
from litmodels import download_model, upload_model
13+
from litmodels import download_model, load_model, upload_model
1414
from litmodels.integrations.duplicate import duplicate_hf_model
1515
from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin
1616
from litmodels.io.cloud import _list_available_teamspaces
17+
from litmodels.io.utils import _KERAS_AVAILABLE
1718
from tests.integrations import (
1819
_SKIP_IF_LIGHTNING_BELLOW_2_5_1,
1920
_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1,
@@ -315,3 +316,32 @@ def test_list_available_teamspaces():
315316
assert len(teams) > 0
316317
# using sanitized teamspace name
317318
assert f"{LIT_ORG}/oss-litmodels" in teams
319+
320+
321+
@pytest.mark.cloud
322+
@pytest.mark.skipif(
323+
not _KERAS_AVAILABLE,
324+
reason="TensorFlow Keras is not supported on Windows for now.",
325+
)
326+
def test_save_load_tensorflow_keras(tmp_path):
327+
from tensorflow import keras
328+
329+
# Define the model
330+
model = keras.Sequential([
331+
keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
332+
keras.layers.Dense(10, name="dense_2"),
333+
])
334+
335+
# Compile the model
336+
model.compile(optimizer="adam", loss="categorical_crossentropy")
337+
338+
# model name with random hash
339+
teamspace, org_team, model_name = _prepare_variables("tf-keras")
340+
upload_model(f"{org_team}/{model_name}", model=model)
341+
342+
# Load the model
343+
model_ = load_model(f"{org_team}/{model_name}", download_dir=str(tmp_path))
344+
345+
# validate the model
346+
assert isinstance(model_, type(model))
347+
_cleanup_model(teamspace, model_name, expected_num_versions=1)

tests/test_io_cloud.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import litmodels
1212
from litmodels import download_model, load_model, upload_model
1313
from litmodels.io import upload_model_files
14+
from litmodels.io.utils import _KERAS_AVAILABLE
1415

1516

1617
@pytest.mark.parametrize(
@@ -109,3 +110,31 @@ def test_load_model_torch_jit(mock_download_model, tmp_path):
109110
name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True
110111
)
111112
assert isinstance(model, torch.jit.ScriptModule)
113+
114+
115+
@pytest.mark.skipif(not _KERAS_AVAILABLE, reason="TensorFlow/Keras is not available")
116+
@mock.patch("litmodels.io.cloud.sdk_download_model")
117+
def test_load_model_tf_keras(mock_download_model, tmp_path):
118+
from tensorflow import keras
119+
120+
# create a dummy model file
121+
model_file = tmp_path / "dummy_model.keras"
122+
# Define the model
123+
model = keras.Sequential([
124+
keras.layers.Dense(10, input_shape=(784,), name="dense_1"),
125+
keras.layers.Dense(10, name="dense_2"),
126+
])
127+
model.compile(optimizer="adam", loss="categorical_crossentropy")
128+
model.save(model_file)
129+
# prepare mocked SDK download function
130+
mock_download_model.return_value = [str(model_file.name)]
131+
132+
# The lit-logger function is just a wrapper around the SDK function
133+
model = load_model(
134+
name="org-name/teamspace/model-name",
135+
download_dir=str(tmp_path),
136+
)
137+
mock_download_model.assert_called_once_with(
138+
name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True
139+
)
140+
assert isinstance(model, keras.models.Model)

0 commit comments

Comments
 (0)