Skip to content
This repository was archived by the owner on Mar 10, 2026. It is now read-only.

Commit a354797

Browse files
Increase test coverage + Fix save_model_to_hdf5 + Improve is_remote_path + Fix is_remote_path (#900)
* Increase test coverage in `saving` * Add FAILED tests TODO * Add tests for `LambdaCallback` * Add tests for `LambdaCallback` * Add test for saving_api.py#L96 * Increase test coverage in `saving` * Increase test coverage * refines the logic `os.makedirs` +Increase tests * Increase test coverage * Increase test coverage * More tests file_utils_test.py+fix bug `rmtree` * More tests `file_utils_test` + fix bug `rmtree` * More tests file_utils_test + fix bug rmtree * Increase test coverage * add tests to `lambda_callback_test` * Add tests in file_utils_test.py * Add tests in file_utils_test.py * Add more tests `file_utils_test` * add class TestValidateFile * Add tests for `TestIsRemotePath` * Add tests in file_utils_test.py * Add tests in file_utils_test.py * Add tests in file_utils_test.py * Add tests in `file_utils_test.py` * fix `is_remote_path` * improve `is_remote_path` * Add test for `raise_if_no_gfile_raises` * Add tests for file_utils.py * Add tests in `saving_api_test.py` * Add tests `saving_api_test.py` * Add tests saving_api_test.py * Add tests in `saving_api_test.py` * Add test `test_directory_creation_on_save` * Add test `legacy_h5_format_test.py` * Flake8 for `LambdaCallbackTest` * use `get_model` and `self.get_temp_dir` * Fix format * Improve `is_remote_path` + Add tests * Fix `is_remote_path`
1 parent 5af4344 commit a354797

File tree

6 files changed

+947
-104
lines changed

6 files changed

+947
-104
lines changed

keras_core/callbacks/lambda_callback_test.py

Lines changed: 123 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212

1313
class LambdaCallbackTest(testing.TestCase):
1414
@pytest.mark.requires_trainable_backend
15-
def test_LambdaCallback(self):
16-
BATCH_SIZE = 4
15+
def test_lambda_callback(self):
16+
"""Test standard LambdaCallback functionalities with training."""
17+
batch_size = 4
1718
model = Sequential(
18-
[layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)]
19+
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
1920
)
2021
model.compile(
2122
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
@@ -34,7 +35,7 @@ def test_LambdaCallback(self):
3435
model.fit(
3536
x,
3637
y,
37-
batch_size=BATCH_SIZE,
38+
batch_size=batch_size,
3839
validation_split=0.2,
3940
callbacks=[lambda_log_callback],
4041
epochs=5,
@@ -44,3 +45,121 @@ def test_LambdaCallback(self):
4445
self.assertTrue(any("on_epoch_begin" in log for log in logs.output))
4546
self.assertTrue(any("on_epoch_end" in log for log in logs.output))
4647
self.assertTrue(any("on_train_end" in log for log in logs.output))
48+
49+
@pytest.mark.requires_trainable_backend
50+
def test_lambda_callback_with_batches(self):
51+
"""Test LambdaCallback's behavior with batch-level callbacks."""
52+
batch_size = 4
53+
model = Sequential(
54+
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
55+
)
56+
model.compile(
57+
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
58+
)
59+
x = np.random.randn(16, 2)
60+
y = np.random.randn(16, 1)
61+
lambda_log_callback = callbacks.LambdaCallback(
62+
on_train_batch_begin=lambda batch, logs: logging.warning(
63+
"on_train_batch_begin"
64+
),
65+
on_train_batch_end=lambda batch, logs: logging.warning(
66+
"on_train_batch_end"
67+
),
68+
)
69+
with self.assertLogs(level="WARNING") as logs:
70+
model.fit(
71+
x,
72+
y,
73+
batch_size=batch_size,
74+
validation_split=0.2,
75+
callbacks=[lambda_log_callback],
76+
epochs=5,
77+
verbose=0,
78+
)
79+
self.assertTrue(
80+
any("on_train_batch_begin" in log for log in logs.output)
81+
)
82+
self.assertTrue(
83+
any("on_train_batch_end" in log for log in logs.output)
84+
)
85+
86+
@pytest.mark.requires_trainable_backend
87+
def test_lambda_callback_with_kwargs(self):
88+
"""Test LambdaCallback's behavior with custom defined callback."""
89+
batch_size = 4
90+
model = Sequential(
91+
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
92+
)
93+
model.compile(
94+
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
95+
)
96+
x = np.random.randn(16, 2)
97+
y = np.random.randn(16, 1)
98+
model.fit(
99+
x, y, batch_size=batch_size, epochs=1, verbose=0
100+
) # Train briefly for evaluation to work.
101+
102+
def custom_on_test_begin(logs):
103+
logging.warning("custom_on_test_begin_executed")
104+
105+
lambda_log_callback = callbacks.LambdaCallback(
106+
on_test_begin=custom_on_test_begin
107+
)
108+
with self.assertLogs(level="WARNING") as logs:
109+
model.evaluate(
110+
x,
111+
y,
112+
batch_size=batch_size,
113+
callbacks=[lambda_log_callback],
114+
verbose=0,
115+
)
116+
self.assertTrue(
117+
any(
118+
"custom_on_test_begin_executed" in log
119+
for log in logs.output
120+
)
121+
)
122+
123+
@pytest.mark.requires_trainable_backend
124+
def test_lambda_callback_no_args(self):
125+
"""Test initializing LambdaCallback without any arguments."""
126+
lambda_callback = callbacks.LambdaCallback()
127+
self.assertIsInstance(lambda_callback, callbacks.LambdaCallback)
128+
129+
@pytest.mark.requires_trainable_backend
130+
def test_lambda_callback_with_additional_kwargs(self):
131+
"""Test initializing LambdaCallback with non-predefined kwargs."""
132+
133+
def custom_callback(logs):
134+
pass
135+
136+
lambda_callback = callbacks.LambdaCallback(
137+
custom_method=custom_callback
138+
)
139+
self.assertTrue(hasattr(lambda_callback, "custom_method"))
140+
141+
@pytest.mark.requires_trainable_backend
142+
def test_lambda_callback_during_prediction(self):
143+
"""Test LambdaCallback's functionality during model prediction."""
144+
batch_size = 4
145+
model = Sequential(
146+
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
147+
)
148+
model.compile(
149+
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
150+
)
151+
x = np.random.randn(16, 2)
152+
153+
def custom_on_predict_begin(logs):
154+
logging.warning("on_predict_begin_executed")
155+
156+
lambda_callback = callbacks.LambdaCallback(
157+
on_predict_begin=custom_on_predict_begin
158+
)
159+
with self.assertLogs(level="WARNING") as logs:
160+
model.predict(
161+
x, batch_size=batch_size, callbacks=[lambda_callback], verbose=0
162+
)
163+
self.assertTrue(
164+
any("on_predict_begin_executed" in log for log in logs.output)
165+
)

keras_core/legacy/saving/legacy_h5_format.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
3737
if not proceed:
3838
return
3939

40-
# Try creating dir if not exist
4140
dirpath = os.path.dirname(filepath)
42-
if not os.path.exists(dirpath):
43-
os.path.makedirs(dirpath)
41+
if dirpath and not os.path.exists(dirpath):
42+
os.makedirs(dirpath, exist_ok=True)
4443

4544
f = h5py.File(filepath, mode="w")
4645
opened_new_file = True

keras_core/legacy/saving/legacy_h5_format_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,3 +481,19 @@ def call(self, x):
481481

482482
# Compare output
483483
self.assertAllClose(ref_output, output, atol=1e-5)
484+
485+
486+
@pytest.mark.requires_trainable_backend
487+
class DirectoryCreationTest(testing.TestCase):
488+
def test_directory_creation_on_save(self):
489+
"""Test if directory is created on model save."""
490+
model = get_sequential_model(keras_core)
491+
nested_dirpath = os.path.join(
492+
self.get_temp_dir(), "dir1", "dir2", "dir3"
493+
)
494+
filepath = os.path.join(nested_dirpath, "model.h5")
495+
self.assertFalse(os.path.exists(nested_dirpath))
496+
legacy_h5_format.save_model_to_hdf5(model, filepath)
497+
self.assertTrue(os.path.exists(nested_dirpath))
498+
loaded_model = legacy_h5_format.load_model_from_hdf5(filepath)
499+
self.assertEqual(model.to_json(), loaded_model.to_json())
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import os
2+
import unittest.mock as mock
3+
4+
import numpy as np
5+
from absl import logging
6+
7+
from keras_core import layers
8+
from keras_core.models import Sequential
9+
from keras_core.saving import saving_api
10+
from keras_core.testing import test_case
11+
12+
13+
class SaveModelTests(test_case.TestCase):
14+
def get_model(self):
15+
return Sequential(
16+
[
17+
layers.Dense(5, input_shape=(3,)),
18+
layers.Softmax(),
19+
]
20+
)
21+
22+
def test_basic_saving(self):
23+
"""Test basic model saving and loading."""
24+
model = self.get_model()
25+
filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
26+
saving_api.save_model(model, filepath)
27+
28+
loaded_model = saving_api.load_model(filepath)
29+
x = np.random.uniform(size=(10, 3))
30+
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))
31+
32+
def test_invalid_save_format(self):
33+
"""Test deprecated save_format argument."""
34+
model = self.get_model()
35+
with self.assertRaisesRegex(
36+
ValueError, "The `save_format` argument is deprecated"
37+
):
38+
saving_api.save_model(model, "model.txt", save_format=True)
39+
40+
def test_unsupported_arguments(self):
41+
"""Test unsupported argument during model save."""
42+
model = self.get_model()
43+
filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
44+
with self.assertRaisesRegex(
45+
ValueError, r"The following argument\(s\) are not supported"
46+
):
47+
saving_api.save_model(model, filepath, random_arg=True)
48+
49+
def test_save_h5_format(self):
50+
"""Test saving model in h5 format."""
51+
model = self.get_model()
52+
filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5")
53+
saving_api.save_model(model, filepath_h5)
54+
self.assertTrue(os.path.exists(filepath_h5))
55+
os.remove(filepath_h5)
56+
57+
def test_save_unsupported_extension(self):
58+
"""Test saving model with unsupported extension."""
59+
model = self.get_model()
60+
with self.assertRaisesRegex(
61+
ValueError, "Invalid filepath extension for saving"
62+
):
63+
saving_api.save_model(model, "model.png")
64+
65+
66+
class LoadModelTests(test_case.TestCase):
67+
def get_model(self):
68+
return Sequential(
69+
[
70+
layers.Dense(5, input_shape=(3,)),
71+
layers.Softmax(),
72+
]
73+
)
74+
75+
def test_basic_load(self):
76+
"""Test basic model loading."""
77+
model = self.get_model()
78+
filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
79+
saving_api.save_model(model, filepath)
80+
81+
loaded_model = saving_api.load_model(filepath)
82+
x = np.random.uniform(size=(10, 3))
83+
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))
84+
85+
def test_load_unsupported_format(self):
86+
"""Test loading model with unsupported format."""
87+
with self.assertRaisesRegex(ValueError, "File format not supported"):
88+
saving_api.load_model("model.pkl")
89+
90+
def test_load_keras_not_zip(self):
91+
"""Test loading keras file that's not a zip."""
92+
with self.assertRaisesRegex(ValueError, "File not found"):
93+
saving_api.load_model("not_a_zip.keras")
94+
95+
def test_load_h5_format(self):
96+
"""Test loading model in h5 format."""
97+
model = self.get_model()
98+
filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5")
99+
saving_api.save_model(model, filepath_h5)
100+
loaded_model = saving_api.load_model(filepath_h5)
101+
x = np.random.uniform(size=(10, 3))
102+
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))
103+
os.remove(filepath_h5)
104+
105+
def test_load_model_with_custom_objects(self):
106+
"""Test loading model with custom objects."""
107+
108+
class CustomLayer(layers.Layer):
109+
def call(self, inputs):
110+
return inputs
111+
112+
model = Sequential([CustomLayer(input_shape=(3,))])
113+
filepath = os.path.join(self.get_temp_dir(), "custom_model.keras")
114+
model.save(filepath)
115+
loaded_model = saving_api.load_model(
116+
filepath, custom_objects={"CustomLayer": CustomLayer}
117+
)
118+
self.assertIsInstance(loaded_model.layers[0], CustomLayer)
119+
os.remove(filepath)
120+
121+
122+
class LoadWeightsTests(test_case.TestCase):
123+
def get_model(self):
124+
return Sequential(
125+
[
126+
layers.Dense(5, input_shape=(3,)),
127+
layers.Softmax(),
128+
]
129+
)
130+
131+
def test_load_keras_weights(self):
132+
"""Test loading keras weights."""
133+
model = self.get_model()
134+
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
135+
model.save_weights(filepath)
136+
original_weights = model.get_weights()
137+
model.load_weights(filepath)
138+
loaded_weights = model.get_weights()
139+
for orig, loaded in zip(original_weights, loaded_weights):
140+
self.assertTrue(np.array_equal(orig, loaded))
141+
142+
def test_load_h5_weights_by_name(self):
143+
"""Test loading h5 weights by name."""
144+
model = self.get_model()
145+
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
146+
model.save_weights(filepath)
147+
with self.assertRaisesRegex(ValueError, "Invalid keyword arguments"):
148+
model.load_weights(filepath, by_name=True)
149+
150+
def test_load_weights_invalid_extension(self):
151+
"""Test loading weights with unsupported extension."""
152+
model = self.get_model()
153+
with self.assertRaisesRegex(ValueError, "File format not supported"):
154+
model.load_weights("invalid_extension.pkl")
155+
156+
157+
class SaveModelTestsWarning(test_case.TestCase):
158+
def get_model(self):
159+
return Sequential(
160+
[
161+
layers.Dense(5, input_shape=(3,)),
162+
layers.Softmax(),
163+
]
164+
)
165+
166+
def test_h5_deprecation_warning(self):
167+
"""Test deprecation warning for h5 format."""
168+
model = self.get_model()
169+
filepath = os.path.join(self.get_temp_dir(), "test_model.h5")
170+
171+
with mock.patch.object(logging, "warning") as mock_warn:
172+
saving_api.save_model(model, filepath)
173+
mock_warn.assert_called_once_with(
174+
"You are saving your model as an HDF5 file via `model.save()`. "
175+
"This file format is considered legacy. "
176+
"We recommend using instead the native Keras format, "
177+
"e.g. `model.save('my_model.keras')`."
178+
)

keras_core/utils/file_utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,19 @@ def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535):
386386

387387

388388
def is_remote_path(filepath):
389-
"""Returns `True` for paths that represent a remote GCS location."""
390-
# TODO: improve generality.
391-
if re.match(r"^(/cns|/cfs|/gcs|.*://).*$", str(filepath)):
389+
"""
390+
Determines if a given filepath indicates a remote location.
391+
392+
This function checks if the filepath represents a known remote pattern
393+
such as GCS (`/gcs`), CNS (`/cns`), CFS (`/cfs`), HDFS (`/hdfs`)
394+
395+
Args:
396+
filepath (str): The path to be checked.
397+
398+
Returns:
399+
bool: True if the filepath is a recognized remote path, otherwise False
400+
"""
401+
if re.match(r"^(/cns|/cfs|/gcs|/hdfs|.*://).*$", str(filepath)):
392402
return True
393403
return False
394404

@@ -445,7 +455,7 @@ def rmtree(path):
445455
return gfile.rmtree(path)
446456
else:
447457
_raise_if_no_gfile(path)
448-
return shutil.rmtree
458+
return shutil.rmtree(path)
449459

450460

451461
def listdir(path):

0 commit comments

Comments
 (0)