
Commit e29d0ef

Version bump and cherry picks for 3.13.2 (#22080)

Authored by: sachinprasadhs, HyperPS, hertschuh, divyashreepathihalli, 0xManan
* Fix DoS via malicious HDF5 dataset metadata in KerasFileEditor (#21880)
  * Fix DoS via malicious HDF5 dataset metadata in KerasFileEditor
  * Refactor: move MAX_BYTES constant outside loop per review feedback
  * Fix: harden HDF5 dataset metadata validation in KerasFileEditor

* Do not allow external links in HDF5 files (#22057). Keras never uses this feature.
  - Verify that we get H5 Groups when expected; otherwise, merely by doing `[key]` we may be loading an external Dataset.
  - Verify that the H5 Datasets are not external links, and fail if they are.
  - Remove the unused methods `items` and `values` in `H5IOStore` and `ShardedH5IOStore`. They are not used, the implementation of `MutableMapping` was incomplete anyway, and these methods would return unverified Datasets.
  - Fix logic related to `failed_saveables` in `load_state`.
  - Preserve the order of keys in the implementation of `ShardedH5IOStore.keys()`.

* Set mutable to True by default in nnx_metadata (#22074)

* Disallow TFSMLayer deserialization in safe_mode to prevent external SavedModel execution (#22035)
  * Implement safe mode checks for loading TFSMLayer from external SavedModels.
  * Align the check logic with the __init__ method; fix indentation and formatting in tfsm_layer.py.
  * Add a setup method to enable unsafe deserialization in the TFSMLayer tests; update TFSMLayer initialization and imports in tfsm_layer_test.py.
  * Remove the safe_mode check from TFSMLayer.__init__(). The safe_mode check belongs only in from_config(), not __init__(): direct instantiation (TFSMLayer(filepath=...)) is a legitimate use case where the user explicitly creates the layer. The security concern arises only during deserialization of untrusted .keras files, which goes through from_config(). An attacker can still create a malicious .keras file, but victims are blocked from loading it with safe_mode=True.
  * Implement comprehensive tests for TFSMLayer safe_mode behavior:
    - test_safe_mode_direct_instantiation_allowed: verifies direct TFSMLayer instantiation works as expected.
    - test_safe_mode_from_config_blocked: verifies from_config() raises ValueError when safe_mode=True.
    - test_safe_mode_from_config_allowed_when_disabled: verifies from_config() works with safe_mode=False.
    - test_safe_mode_model_loading_blocked: tests the full attack scenario where loading a .keras file with safe_mode=True is blocked.
  * Clarify test docstrings on instantiation and loading behavior; invoke the model with random input in the tfsm_layer tests.
  * Set the safe_mode default to True in from_config; later set `safe_mode=None` in `from_config`, which fixes the unit tests, and re-add empty lines.
  * Assorted follow-up cleanups in tfsm_layer.py and tfsm_layer_test.py: comment updates, format and ruff fixes, a new test case, and removal of unneeded `custom_objects` in the unit tests.

  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
  Co-authored-by: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com>

* Patch release 3.12.2 changes

---------

Co-authored-by: sarvesh patil <103917093+HyperPS@users.noreply.github.com>
Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com>
Co-authored-by: Divyashree Sreepathihalli <divyashreepathihalli@gmail.com>
Co-authored-by: Manan Patel <70314133+0xManan@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 8914427 commit e29d0ef

File tree

7 files changed: +201 -76 lines


keras/src/backend/jax/core.py

Lines changed: 1 addition & 1 deletion

@@ -98,7 +98,7 @@ def __init__(
     ):
         # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable'
        # param takes precedence.
-        nnx_metadata["mutable"] = trainable if mutable is None else mutable
+        nnx_metadata["mutable"] = True if mutable is None else mutable

         # First, initialize a basic nnx.Variable with a dummy value
         # This sets up the NNX variable structure
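The one-line change above flips the fallback for `mutable` from `trainable` to `True`. A minimal, standalone sketch of the before/after behavior (hypothetical helper names, not the Keras implementation itself):

```python
# Hypothetical helpers illustrating the old and new defaults for `mutable`.
def resolve_mutable_old(trainable, mutable=None):
    # Before the fix: an unset `mutable` inherited the `trainable` flag.
    return trainable if mutable is None else mutable


def resolve_mutable_new(trainable, mutable=None):
    # After the fix: an unset `mutable` defaults to True, regardless of
    # `trainable`. An explicit `mutable` argument still takes precedence.
    return True if mutable is None else mutable


print(resolve_mutable_old(False))         # False
print(resolve_mutable_new(False))         # True
print(resolve_mutable_new(False, False))  # False (explicit value wins)
```

The practical difference: a non-trainable variable used to be frozen by default in nnx_metadata; it is now mutable unless explicitly frozen.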

keras/src/export/tfsm_layer.py

Lines changed: 34 additions & 0 deletions

@@ -2,6 +2,7 @@
 from keras.src import layers
 from keras.src.api_export import keras_export
 from keras.src.export.saved_model import _list_variables_used_by_fns
+from keras.src.saving import serialization_lib
 from keras.src.utils.module_utils import tensorflow as tf


@@ -146,3 +147,36 @@ def get_config(self):
             "call_training_endpoint": self.call_training_endpoint,
         }
         return {**base_config, **config}
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None, safe_mode=None):
+        """Creates a TFSMLayer from its config.
+
+        Args:
+            config: A Python dictionary, typically the output of
+                `get_config`.
+            custom_objects: Optional dictionary mapping names to custom
+                objects.
+            safe_mode: Boolean, whether to disallow loading TFSMLayer.
+                When `safe_mode=True`, loading is disallowed because
+                TFSMLayer loads external SavedModels that may contain
+                attacker-controlled executable graph code. Defaults to
+                `True`.
+
+        Returns:
+            A TFSMLayer instance.
+        """
+        # Follow the same pattern as Lambda layer for safe_mode handling
+        effective_safe_mode = (
+            safe_mode
+            if safe_mode is not None
+            else serialization_lib.in_safe_mode()
+        )
+
+        if effective_safe_mode is not False:
+            raise ValueError(
+                "Requested the deserialization of a `TFSMLayer`, which "
+                "loads an external SavedModel. This carries a potential risk "
+                "of arbitrary code execution and thus it is disallowed by "
+                "default. If you trust the source of the artifact, you can "
+                "override this error by passing `safe_mode=False` to the "
+                "loading function, or calling "
+                "`keras.config.enable_unsafe_deserialization()`."
+            )
+
+        return cls(**config)
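The `from_config` addition gates deserialization on a three-way resolution: an explicit `safe_mode` argument wins, otherwise the global safe-mode flag decides, and anything other than an explicit `False` blocks loading. A minimal sketch of that gate, with `in_safe_mode` standing in for `serialization_lib.in_safe_mode()` (assumed here to default to `True`):

```python
def in_safe_mode():
    # Stand-in for serialization_lib.in_safe_mode(); the real function
    # reads a global flag managed by the Keras loading machinery.
    return True


def check_tfsm_deserialization(safe_mode=None):
    # An explicit argument takes precedence over the global flag.
    effective_safe_mode = (
        safe_mode if safe_mode is not None else in_safe_mode()
    )
    # Only an explicit opt-out (False) allows loading; True and None
    # (with a safe global default) both block it.
    if effective_safe_mode is not False:
        raise ValueError(
            "TFSMLayer deserialization is disallowed in safe mode."
        )
    return "allowed"


print(check_tfsm_deserialization(safe_mode=False))  # allowed
```

Note the deliberate use of `is not False` rather than truthiness: it makes the default path (`None` resolved via the global flag) fail closed, so callers must opt out explicitly.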

keras/src/export/tfsm_layer_test.py

Lines changed: 32 additions & 3 deletions

@@ -114,19 +114,48 @@ def test_serialization(self):

         # Test reinstantiation from config
         config = reloaded_layer.get_config()
-        rereloaded_layer = tfsm_layer.TFSMLayer.from_config(config)
+        rereloaded_layer = tfsm_layer.TFSMLayer.from_config(
+            config, safe_mode=False
+        )
         self.assertAllClose(rereloaded_layer(ref_input), ref_output, atol=1e-7)

         # Test whole model saving with reloaded layer inside
         model = models.Sequential([reloaded_layer])
         temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras")
         model.save(temp_model_filepath, save_format="keras_v3")
         reloaded_model = saving_lib.load_model(
-            temp_model_filepath,
-            custom_objects={"TFSMLayer": tfsm_layer.TFSMLayer},
+            temp_model_filepath, safe_mode=False
         )
         self.assertAllClose(reloaded_model(ref_input), ref_output, atol=1e-7)

+    def test_safe_mode_blocks_model_loading(self):
+        temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
+
+        # Create and export a model
+        model = get_model()
+        model(tf.random.normal((1, 10)))
+        saved_model.export_saved_model(model, temp_filepath)
+
+        # Wrap SavedModel in TFSMLayer and save as .keras
+        reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath)
+        wrapper_model = models.Sequential([reloaded_layer])
+
+        model_path = os.path.join(self.get_temp_dir(), "tfsm_model.keras")
+        wrapper_model.save(model_path)
+
+        # Default safe_mode=True should block loading
+        with self.assertRaisesRegex(
+            ValueError,
+            "arbitrary code execution",
+        ):
+            saving_lib.load_model(model_path)
+
+        # Explicit opt-out should allow loading
+        loaded_model = saving_lib.load_model(model_path, safe_mode=False)
+
+        x = tf.random.normal((2, 10))
+        self.assertAllClose(loaded_model(x), wrapper_model(x))
+
     def test_errors(self):
         # Test missing call endpoint
         temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")

keras/src/saving/file_editor.py

Lines changed: 87 additions & 6 deletions

@@ -455,33 +455,114 @@ def resave_weights(self, filepath):
     def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
         metadata = metadata or {}

+        # ------------------------------------------------------
+        # Collect metadata for this HDF5 group
+        # ------------------------------------------------------
         object_metadata = {}
         for k, v in data.attrs.items():
             object_metadata[k] = v
         if object_metadata:
             metadata[inner_path] = object_metadata

         result = collections.OrderedDict()
+
+        # ------------------------------------------------------
+        # Iterate over all keys in this HDF5 group
+        # ------------------------------------------------------
         for key in data.keys():
-            inner_path = f"{inner_path}/{key}"
+            # IMPORTANT:
+            # Never mutate inner_path; use a local variable.
+            current_inner_path = f"{inner_path}/{key}"
             value = data[key]
+
+            # ------------------------------------------------------
+            # CASE 1 — HDF5 GROUP → RECURSE
+            # ------------------------------------------------------
             if isinstance(value, h5py.Group):
+                # Skip empty groups
                 if len(value) == 0:
                     continue
+
+                # Skip empty "vars" groups
                 if "vars" in value.keys() and len(value["vars"]) == 0:
                     continue

-            if hasattr(value, "keys"):
+                # Recurse into the "vars" subgroup when present
                 if "vars" in value.keys():
                     result[key], metadata = self._extract_weights_from_store(
-                        value["vars"], metadata=metadata, inner_path=inner_path
+                        value["vars"],
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
                 else:
+                    # Recurse normally
                     result[key], metadata = self._extract_weights_from_store(
-                        value, metadata=metadata, inner_path=inner_path
+                        value,
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
-            else:
-                result[key] = value[()]
+
+                continue  # finished processing this key
+
+            # ------------------------------------------------------
+            # CASE 2 — HDF5 DATASET → SAFE LOADING
+            # ------------------------------------------------------
+
+            # Skip any objects that are not proper datasets
+            if not isinstance(value, h5py.Dataset):
+                continue
+
+            if value.external:
+                raise ValueError(
+                    "Not allowed: H5 file Dataset with external links: "
+                    f"{value.external}"
+                )
+
+            shape = value.shape
+            dtype = value.dtype
+
+            # ------------------------------------------------------
+            # Validate SHAPE (avoid malformed / malicious metadata)
+            # ------------------------------------------------------
+
+            # No negative dimensions
+            if any(dim < 0 for dim in shape):
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras "
+                    "file; negative dimension detected."
+                )
+
+            # Prevent absurdly high-rank tensors
+            if len(shape) > 64:
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras "
+                    "file; tensor rank exceeds safety limit."
+                )
+
+            # Safe product computation (Python int is unbounded)
+            num_elems = int(np.prod(shape))
+
+            # ------------------------------------------------------
+            # Validate TOTAL memory size
+            # ------------------------------------------------------
+            MAX_BYTES = 1 << 32  # 4 GiB
+
+            size_bytes = num_elems * dtype.itemsize
+
+            if size_bytes > MAX_BYTES:
+                raise ValueError(
+                    f"HDF5 dataset too large to load safely "
+                    f"({size_bytes} bytes; limit is {MAX_BYTES})."
+                )
+
+            # ------------------------------------------------------
+            # SAFE — load dataset (guaranteed ≤ 4 GiB)
+            # ------------------------------------------------------
+            result[key] = value[()]
+
+        # ------------------------------------------------------
+        # Return final tree and metadata
+        # ------------------------------------------------------
         return result, metadata

     def _generate_filepath_info(self, rich_style=False):
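The metadata checks in the hunk above can be exercised without h5py or a .keras file, since they operate only on a shape tuple and an element size. This is a pure-Python sketch (hypothetical function name) of the same three validations: no negative dimensions, rank at most 64, and total size at most 4 GiB, all computed before any allocation would happen:

```python
import math

MAX_BYTES = 1 << 32  # 4 GiB, the same limit used in the diff
MAX_RANK = 64


def validate_dataset_metadata(shape, itemsize):
    """Return the dataset size in bytes, or raise on suspicious metadata."""
    if any(dim < 0 for dim in shape):
        raise ValueError("negative dimension detected")
    if len(shape) > MAX_RANK:
        raise ValueError("tensor rank exceeds safety limit")
    # Python ints are unbounded, so this product cannot silently overflow
    # the way a fixed-width C integer could.
    size_bytes = math.prod(shape) * itemsize
    if size_bytes > MAX_BYTES:
        raise ValueError(f"dataset too large ({size_bytes} bytes)")
    return size_bytes


print(validate_dataset_metadata((1024, 1024), 4))  # 4194304
```

The point of validating from metadata alone is that a malicious HDF5 file can declare an enormous shape in a few bytes of header; checking `shape` and `dtype.itemsize` before reading `value[()]` turns a potential multi-gigabyte allocation (DoS) into a cheap ValueError.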
