Skip to content

Commit ac5c97f

Browse files
authored
Propagate safe_mode flag to legacy h5 loading code. (#21602)
Also: - made various error messages related to `safe_mode` more consistent - removed no-op renaming code in legacy saving - uncommented unit tests in `serialization_lib_test.py`
1 parent 79413bc commit ac5c97f

File tree

10 files changed

+65
-75
lines changed

10 files changed

+65
-75
lines changed

keras/src/layers/core/lambda_layer.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,15 @@ def _serialize_function_to_config(self, fn):
167167
)
168168

169169
@staticmethod
170-
def _raise_for_lambda_deserialization(arg_name, safe_mode):
170+
def _raise_for_lambda_deserialization(safe_mode):
171171
if safe_mode:
172172
raise ValueError(
173-
f"The `{arg_name}` of this `Lambda` layer is a Python lambda. "
174-
"Deserializing it is unsafe. If you trust the source of the "
175-
"config artifact, you can override this error "
176-
"by passing `safe_mode=False` "
177-
"to `from_config()`, or calling "
173+
"Requested the deserialization of a `Lambda` layer whose "
174+
"`function` is a Python lambda. This carries a potential risk "
175+
"of arbitrary code execution and thus it is disallowed by "
176+
"default. If you trust the source of the artifact, you can "
177+
"override this error by passing `safe_mode=False` to the "
178+
"loading function, or calling "
178179
"`keras.config.enable_unsafe_deserialization()."
179180
)
180181

@@ -187,7 +188,7 @@ def from_config(cls, config, custom_objects=None, safe_mode=None):
187188
and "class_name" in fn_config
188189
and fn_config["class_name"] == "__lambda__"
189190
):
190-
cls._raise_for_lambda_deserialization("function", safe_mode)
191+
cls._raise_for_lambda_deserialization(safe_mode)
191192
inner_config = fn_config["config"]
192193
fn = python_utils.func_load(
193194
inner_config["code"],
@@ -206,7 +207,7 @@ def from_config(cls, config, custom_objects=None, safe_mode=None):
206207
and "class_name" in fn_config
207208
and fn_config["class_name"] == "__lambda__"
208209
):
209-
cls._raise_for_lambda_deserialization("function", safe_mode)
210+
cls._raise_for_lambda_deserialization(safe_mode)
210211
inner_config = fn_config["config"]
211212
fn = python_utils.func_load(
212213
inner_config["code"],

keras/src/legacy/saving/legacy_h5_format.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from keras.src.legacy.saving import saving_options
1212
from keras.src.legacy.saving import saving_utils
1313
from keras.src.saving import object_registration
14+
from keras.src.saving import serialization_lib
1415
from keras.src.utils import io_utils
1516

1617
try:
@@ -72,7 +73,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
7273
f.close()
7374

7475

75-
def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
76+
def load_model_from_hdf5(
77+
filepath, custom_objects=None, compile=True, safe_mode=True
78+
):
7679
"""Loads a model saved via `save_model_to_hdf5`.
7780
7881
Args:
@@ -128,7 +131,9 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
128131
model_config = model_config.decode("utf-8")
129132
model_config = json_utils.decode(model_config)
130133

131-
with saving_options.keras_option_scope(use_legacy_config=True):
134+
legacy_scope = saving_options.keras_option_scope(use_legacy_config=True)
135+
safe_mode_scope = serialization_lib.SafeModeScope(safe_mode)
136+
with legacy_scope, safe_mode_scope:
132137
model = saving_utils.model_from_config(
133138
model_config, custom_objects=custom_objects
134139
)

keras/src/legacy/saving/legacy_h5_format_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,13 @@ def test_saving_lambda(self):
158158

159159
temp_filepath = os.path.join(self.get_temp_dir(), "lambda_model.h5")
160160
legacy_h5_format.save_model_to_hdf5(model, temp_filepath)
161-
loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)
162161

162+
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
163+
legacy_h5_format.load_model_from_hdf5(temp_filepath)
164+
165+
loaded = legacy_h5_format.load_model_from_hdf5(
166+
temp_filepath, safe_mode=False
167+
)
163168
self.assertAllClose(mean, loaded.layers[1].arguments["mu"])
164169
self.assertAllClose(std, loaded.layers[1].arguments["std"])
165170

@@ -353,8 +358,13 @@ def test_saving_lambda(self):
353358

354359
temp_filepath = os.path.join(self.get_temp_dir(), "lambda_model.h5")
355360
tf_keras_model.save(temp_filepath)
356-
loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)
357361

362+
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
363+
legacy_h5_format.load_model_from_hdf5(temp_filepath)
364+
365+
loaded = legacy_h5_format.load_model_from_hdf5(
366+
temp_filepath, safe_mode=False
367+
)
358368
self.assertAllClose(mean, loaded.layers[1].arguments["mu"])
359369
self.assertAllClose(std, loaded.layers[1].arguments["std"])
360370

keras/src/legacy/saving/saving_utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import threading
32

43
from absl import logging
@@ -81,10 +80,6 @@ def model_from_config(config, custom_objects=None):
8180
function_dict["config"]["closure"] = function_config[2]
8281
config["config"]["function"] = function_dict
8382

84-
# TODO(nkovela): Swap find and replace args during Keras 3.0 release
85-
# Replace keras refs with keras
86-
config = _find_replace_nested_dict(config, "keras.", "keras.")
87-
8883
return serialization.deserialize_keras_object(
8984
config,
9085
module_objects=MODULE_OBJECTS.ALL_OBJECTS,
@@ -231,13 +226,6 @@ def _deserialize_metric(metric_config):
231226
return metrics_module.deserialize(metric_config)
232227

233228

234-
def _find_replace_nested_dict(config, find, replace):
235-
dict_str = json.dumps(config)
236-
dict_str = dict_str.replace(find, replace)
237-
config = json.loads(dict_str)
238-
return config
239-
240-
241229
def _resolve_compile_arguments_compat(obj, obj_config, module):
242230
"""Resolves backwards compatibility issues with training config arguments.
243231

keras/src/legacy/saving/serialization.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import contextlib
44
import inspect
5-
import json
65
import threading
76
import weakref
87

@@ -485,12 +484,6 @@ def deserialize(config, custom_objects=None):
485484
arg_spec = inspect.getfullargspec(cls.from_config)
486485
custom_objects = custom_objects or {}
487486

488-
# TODO(nkovela): Swap find and replace args during Keras 3.0 release
489-
# Replace keras refs with keras
490-
cls_config = _find_replace_nested_dict(
491-
cls_config, "keras.", "keras."
492-
)
493-
494487
if "custom_objects" in arg_spec.args:
495488
deserialized_obj = cls.from_config(
496489
cls_config,
@@ -565,10 +558,3 @@ def validate_config(config):
565558
def is_default(method):
566559
"""Check if a method is decorated with the `default` wrapper."""
567560
return getattr(method, "_is_default", False)
568-
569-
570-
def _find_replace_nested_dict(config, find, replace):
571-
dict_str = json.dumps(config)
572-
dict_str = dict_str.replace(find, replace)
573-
config = json.loads(dict_str)
574-
return config

keras/src/saving/saving_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,10 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
194194
)
195195
if str(filepath).endswith((".h5", ".hdf5")):
196196
return legacy_h5_format.load_model_from_hdf5(
197-
filepath, custom_objects=custom_objects, compile=compile
197+
filepath,
198+
custom_objects=custom_objects,
199+
compile=compile,
200+
safe_mode=safe_mode,
198201
)
199202
elif str(filepath).endswith(".keras"):
200203
raise ValueError(

keras/src/saving/saving_lib_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ def test_safe_mode(self):
880880
]
881881
)
882882
model.save(temp_filepath)
883-
with self.assertRaisesRegex(ValueError, "Deserializing it is unsafe"):
883+
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
884884
model = saving_lib.load_model(temp_filepath)
885885
model = saving_lib.load_model(temp_filepath, safe_mode=False)
886886

keras/src/saving/serialization_lib.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -656,12 +656,12 @@ class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
656656
if config["class_name"] == "__lambda__":
657657
if safe_mode:
658658
raise ValueError(
659-
"Requested the deserialization of a `lambda` object. "
660-
"This carries a potential risk of arbitrary code execution "
661-
"and thus it is disallowed by default. If you trust the "
662-
"source of the saved model, you can pass `safe_mode=False` to "
663-
"the loading function in order to allow `lambda` loading, "
664-
"or call `keras.config.enable_unsafe_deserialization()`."
659+
"Requested the deserialization of a Python lambda. This "
660+
"carries a potential risk of arbitrary code execution and thus "
661+
"it is disallowed by default. If you trust the source of the "
662+
"artifact, you can override this error by passing "
663+
"`safe_mode=False` to the loading function, or calling "
664+
"`keras.config.enable_unsafe_deserialization()."
665665
)
666666
return python_utils.func_load(inner_config["value"])
667667
if tf is not None and config["class_name"] == "__typespec__":

keras/src/saving/serialization_lib_test.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -175,31 +175,28 @@ def test_lambda_fn(self):
175175
_, new_obj, _ = self.roundtrip(obj, safe_mode=False)
176176
self.assertEqual(obj["activation"](3), new_obj["activation"](3))
177177

178-
# TODO
179-
# def test_lambda_layer(self):
180-
# lmbda = keras.layers.Lambda(lambda x: x**2)
181-
# with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
182-
# self.roundtrip(lmbda, safe_mode=True)
183-
184-
# _, new_lmbda, _ = self.roundtrip(lmbda, safe_mode=False)
185-
# x = ops.random.normal((2, 2))
186-
# y1 = lmbda(x)
187-
# y2 = new_lmbda(x)
188-
# self.assertAllClose(y1, y2, atol=1e-5)
189-
190-
# def test_safe_mode_scope(self):
191-
# lmbda = keras.layers.Lambda(lambda x: x**2)
192-
# with serialization_lib.SafeModeScope(safe_mode=True):
193-
# with self.assertRaisesRegex(
194-
# ValueError, "arbitrary code execution"
195-
# ):
196-
# self.roundtrip(lmbda)
197-
# with serialization_lib.SafeModeScope(safe_mode=False):
198-
# _, new_lmbda, _ = self.roundtrip(lmbda)
199-
# x = ops.random.normal((2, 2))
200-
# y1 = lmbda(x)
201-
# y2 = new_lmbda(x)
202-
# self.assertAllClose(y1, y2, atol=1e-5)
178+
def test_lambda_layer(self):
179+
lmbda = keras.layers.Lambda(lambda x: x**2)
180+
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
181+
self.roundtrip(lmbda, safe_mode=True)
182+
183+
_, new_lmbda, _ = self.roundtrip(lmbda, safe_mode=False)
184+
x = ops.random.normal((2, 2))
185+
y1 = lmbda(x)
186+
y2 = new_lmbda(x)
187+
self.assertAllClose(y1, y2, atol=1e-5)
188+
189+
def test_safe_mode_scope(self):
190+
lmbda = keras.layers.Lambda(lambda x: x**2)
191+
with serialization_lib.SafeModeScope(safe_mode=True):
192+
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
193+
self.roundtrip(lmbda)
194+
with serialization_lib.SafeModeScope(safe_mode=False):
195+
_, new_lmbda, _ = self.roundtrip(lmbda)
196+
x = ops.random.normal((2, 2))
197+
y1 = lmbda(x)
198+
y2 = new_lmbda(x)
199+
self.assertAllClose(y1, y2, atol=1e-5)
203200

204201
@pytest.mark.requires_trainable_backend
205202
def test_dict_inputs_outputs(self):

keras/src/utils/torch_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,10 @@ def from_config(cls, config):
172172
"Requested the deserialization of a `torch.nn.Module` "
173173
"object via `torch.load()`. This carries a potential risk "
174174
"of arbitrary code execution and thus it is disallowed by "
175-
"default. If you trust the source of the saved model, you "
176-
"can pass `safe_mode=False` to the loading function in "
177-
"order to allow `torch.nn.Module` loading, or call "
178-
"`keras.config.enable_unsafe_deserialization()`."
175+
"default. If you trust the source of the artifact, you can "
176+
"override this error by passing `safe_mode=False` to the "
177+
"loading function, or calling "
178+
"`keras.config.enable_unsafe_deserialization()."
179179
)
180180

181181
# Decode the base64 string back to bytes

0 commit comments

Comments
 (0)