
Commit f704c88

Authored by: sachinprasadhs, HyperPS, hertschuh, 0xManan, gemini-code-assist[bot]
3.12.1 cherry pick changes for patch release. (#22081)
* Fix DoS via malicious HDF5 dataset metadata in KerasFileEditor (#21880)
  * Fix DoS via malicious HDF5 dataset metadata in KerasFileEditor
  * Refactor: move MAX_BYTES constant outside loop per review feedback
  * Fix: harden HDF5 dataset metadata validation in KerasFileEditor

* Do not allow external links in HDF5 files. (#22057) Keras never uses this feature.
  - Verify that we get H5 Groups when expected; otherwise, merely by doing `[key]` we may be loading an external Dataset.
  - Verify that the H5 Datasets are not external links, and fail if they are.
  - Remove unused methods `items` and `values` in `H5IOStore` and `ShardedH5IOStore`. They are not used, the implementation of `MutableMapping` was incomplete anyway, and these methods would return unverified Datasets.
  - Fix logic related to `failed_saveables` in `load_state`.
  - Preserve the order of keys in the implementation of `ShardedH5IOStore.keys()`.

* Disallow TFSMLayer deserialization in safe_mode to prevent external SavedModel execution (#22035)
  * Implement safe mode checks in TFSMLayer: added safe mode checks for loading TFSMLayer from external SavedModels.
  * Update keras/src/export/tfsm_layer.py
  * Align logic with __init__ method for robust checks
  * Fix indentation and formatting in tfsm_layer.py
  * Add setup method to enable unsafe deserialization for TFSMLayer tests.
  * Update TFSMLayer initialization in tests
  * Fix import for TFSMLayer in tfsm_layer_test.py
  * Remove safe_mode check from TFSMLayer.__init__(). The safe_mode check should only be in from_config(), not __init__(). Direct instantiation (TFSMLayer(filepath=...)) is a legitimate use case where the user explicitly creates the layer. The security concern arises only during deserialization of untrusted .keras files, which goes through from_config(). Attackers can still create malicious .keras files, but victims are blocked from loading them with safe_mode=True.
  * Implement comprehensive tests for TFSMLayer safe_mode behavior:
    - test_safe_mode_direct_instantiation_allowed: verifies direct TFSMLayer instantiation works as expected
    - test_safe_mode_from_config_blocked: verifies from_config() raises ValueError when safe_mode=True
    - test_safe_mode_from_config_allowed_when_disabled: verifies from_config() works with safe_mode=False
    - test_safe_mode_model_loading_blocked: tests the full attack scenario where loading a .keras file with safe_mode=True is blocked
  * Clarify test docstrings in tfsm_layer_test.py on instantiation and loading behavior.
  * Invoke model with random input in tfsm_layer tests.
  * Set safe_mode default to True in from_config method
  * Various iterations on tfsm_layer.py and tfsm_layer_test.py: new test cases, comment updates, ruff and formatting fixes, removal of unnecessary changes.
  * Set `safe_mode=None` in `from_config`, which fixes the unit tests. Also re-added empty lines.
  * Remove unneeded `custom_objects` in unit tests.

* Commit 3.12.1 cherry-pick changes

Co-authored-by: sarvesh patil <103917093+HyperPS@users.noreply.github.com>
Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com>
Co-authored-by: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com>
Co-authored-by: Manan Patel <70314133+0xManan@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent adbfd13 commit f704c88

File tree

6 files changed: +200 -75 lines changed

keras/src/export/tfsm_layer.py

Lines changed: 34 additions & 0 deletions

@@ -2,6 +2,7 @@
 from keras.src import layers
 from keras.src.api_export import keras_export
 from keras.src.export.saved_model import _list_variables_used_by_fns
+from keras.src.saving import serialization_lib
 from keras.src.utils.module_utils import tensorflow as tf
 
 
@@ -146,3 +147,36 @@ def get_config(self):
             "call_training_endpoint": self.call_training_endpoint,
         }
         return {**base_config, **config}
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None, safe_mode=None):
+        """Creates a TFSMLayer from its config.
+        Args:
+            config: A Python dictionary, typically the output of `get_config`.
+            custom_objects: Optional dictionary mapping names to custom objects.
+            safe_mode: Boolean, whether to disallow loading TFSMLayer.
+                When `safe_mode=True`, loading is disallowed because TFSMLayer
+                loads external SavedModels that may contain attacker-controlled
+                executable graph code. Defaults to `True`.
+        Returns:
+            A TFSMLayer instance.
+        """
+        # Follow the same pattern as Lambda layer for safe_mode handling
+        effective_safe_mode = (
+            safe_mode
+            if safe_mode is not None
+            else serialization_lib.in_safe_mode()
+        )
+
+        if effective_safe_mode is not False:
+            raise ValueError(
+                "Requested the deserialization of a `TFSMLayer`, which "
+                "loads an external SavedModel. This carries a potential risk "
+                "of arbitrary code execution and thus it is disallowed by "
+                "default. If you trust the source of the artifact, you can "
+                "override this error by passing `safe_mode=False` to the "
+                "loading function, or calling "
+                "`keras.config.enable_unsafe_deserialization()."
+            )
+
+        return cls(**config)
keras/src/export/tfsm_layer_test.py

Lines changed: 32 additions & 3 deletions

@@ -114,19 +114,48 @@ def test_serialization(self):
 
         # Test reinstantiation from config
         config = reloaded_layer.get_config()
-        rereloaded_layer = tfsm_layer.TFSMLayer.from_config(config)
+        rereloaded_layer = tfsm_layer.TFSMLayer.from_config(
+            config, safe_mode=False
+        )
         self.assertAllClose(rereloaded_layer(ref_input), ref_output, atol=1e-7)
 
         # Test whole model saving with reloaded layer inside
         model = models.Sequential([reloaded_layer])
         temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras")
         model.save(temp_model_filepath, save_format="keras_v3")
         reloaded_model = saving_lib.load_model(
-            temp_model_filepath,
-            custom_objects={"TFSMLayer": tfsm_layer.TFSMLayer},
+            temp_model_filepath, safe_mode=False
         )
         self.assertAllClose(reloaded_model(ref_input), ref_output, atol=1e-7)
 
+    def test_safe_mode_blocks_model_loading(self):
+        temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
+
+        # Create and export a model
+        model = get_model()
+        model(tf.random.normal((1, 10)))
+        saved_model.export_saved_model(model, temp_filepath)
+
+        # Wrap SavedModel in TFSMLayer and save as .keras
+        reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath)
+        wrapper_model = models.Sequential([reloaded_layer])
+
+        model_path = os.path.join(self.get_temp_dir(), "tfsm_model.keras")
+        wrapper_model.save(model_path)
+
+        # Default safe_mode=True should block loading
+        with self.assertRaisesRegex(
+            ValueError,
+            "arbitrary code execution",
+        ):
+            saving_lib.load_model(model_path)
+
+        # Explicit opt-out should allow loading
+        loaded_model = saving_lib.load_model(model_path, safe_mode=False)
+
+        x = tf.random.normal((2, 10))
+        self.assertAllClose(loaded_model(x), wrapper_model(x))
+
     def test_errors(self):
         # Test missing call endpoint
         temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")

keras/src/saving/file_editor.py

Lines changed: 87 additions & 6 deletions

@@ -455,33 +455,114 @@ def resave_weights(self, filepath):
     def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
         metadata = metadata or {}
 
+        # ------------------------------------------------------
+        # Collect metadata for this HDF5 group
+        # ------------------------------------------------------
         object_metadata = {}
         for k, v in data.attrs.items():
             object_metadata[k] = v
         if object_metadata:
             metadata[inner_path] = object_metadata
 
         result = collections.OrderedDict()
+
+        # ------------------------------------------------------
+        # Iterate over all keys in this HDF5 group
+        # ------------------------------------------------------
         for key in data.keys():
-            inner_path = f"{inner_path}/{key}"
+            # IMPORTANT:
+            # Never mutate inner_path; use local variable.
+            current_inner_path = f"{inner_path}/{key}"
             value = data[key]
+
+            # ------------------------------------------------------
+            # CASE 1 — HDF5 GROUP → RECURSE
+            # ------------------------------------------------------
             if isinstance(value, h5py.Group):
+                # Skip empty groups
                 if len(value) == 0:
                     continue
+
+                # Skip empty "vars" groups
                 if "vars" in value.keys() and len(value["vars"]) == 0:
                     continue
 
-            if hasattr(value, "keys"):
+                # Recurse into "vars" subgroup when present
                 if "vars" in value.keys():
                     result[key], metadata = self._extract_weights_from_store(
-                        value["vars"], metadata=metadata, inner_path=inner_path
+                        value["vars"],
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
                 else:
+                    # Recurse normally
                     result[key], metadata = self._extract_weights_from_store(
-                        value, metadata=metadata, inner_path=inner_path
+                        value,
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
-            else:
-                result[key] = value[()]
+
+                continue  # finished processing this key
+
+            # ------------------------------------------------------
+            # CASE 2 — HDF5 DATASET → SAFE LOADING
+            # ------------------------------------------------------
+
+            # Skip any objects that are not proper datasets
+            if not isinstance(value, h5py.Dataset):
+                continue
+
+            if value.external:
+                raise ValueError(
+                    "Not allowed: H5 file Dataset with external links: "
+                    f"{value.external}"
+                )
+
+            shape = value.shape
+            dtype = value.dtype
+
+            # ------------------------------------------------------
+            # Validate SHAPE (avoid malformed / malicious metadata)
+            # ------------------------------------------------------
+
+            # No negative dimensions
+            if any(dim < 0 for dim in shape):
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "negative dimension detected."
+                )
+
+            # Prevent absurdly high-rank tensors
+            if len(shape) > 64:
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "tensor rank exceeds safety limit."
+                )
+
+            # Safe product computation (Python int is unbounded)
+            num_elems = int(np.prod(shape))
+
+            # ------------------------------------------------------
+            # Validate TOTAL memory size
+            # ------------------------------------------------------
+            MAX_BYTES = 1 << 32  # 4 GiB
+
+            size_bytes = num_elems * dtype.itemsize
+
+            if size_bytes > MAX_BYTES:
+                raise ValueError(
+                    f"HDF5 dataset too large to load safely "
+                    f"({size_bytes} bytes; limit is {MAX_BYTES})."
+                )
+
+            # ------------------------------------------------------
+            # SAFE — load dataset (guaranteed ≤ 4 GiB)
+            # ------------------------------------------------------
+            result[key] = value[()]
+
+        # ------------------------------------------------------
+        # Return final tree and metadata
+        # ------------------------------------------------------
         return result, metadata
 
     def _generate_filepath_info(self, rich_style=False):
keras/src/saving/saving_lib.py

Lines changed: 46 additions & 47 deletions

@@ -796,7 +796,8 @@ def _load_state(
         try:
             saveable.load_own_variables(weights_store.get(inner_path))
         except Exception as e:
-            failed_saveables.add(id(saveable))
+            if failed_saveables is not None:
+                failed_saveables.add(id(saveable))
             error_msgs[id(saveable)] = saveable, e
             failure = True
         else:
@@ -807,7 +808,8 @@ def _load_state(
         try:
             saveable.load_assets(assets_store.get(inner_path))
         except Exception as e:
-            failed_saveables.add(id(saveable))
+            if failed_saveables is not None:
+                failed_saveables.add(id(saveable))
             error_msgs[id(saveable)] = saveable, e
             failure = True
         else:
@@ -855,7 +857,7 @@ def _load_state(
     if not failure:
        if visited_saveables is not None and newly_failed <= 0:
            visited_saveables.add(id(saveable))
-       if id(saveable) in failed_saveables:
+       if failed_saveables is not None and id(saveable) in failed_saveables:
            failed_saveables.remove(id(saveable))
            error_msgs.pop(id(saveable))
 
@@ -1035,6 +1037,25 @@ def __bool__(self):
         # will mistakenly using `__len__` to determine the value.
         return self.h5_file.__bool__()
 
+    def _verify_group(self, group):
+        if not isinstance(group, h5py.Group):
+            raise ValueError(
+                f"Invalid H5 file, expected Group but received {type(group)}"
+            )
+        return group
+
+    def _verify_dataset(self, dataset):
+        if not isinstance(dataset, h5py.Dataset):
+            raise ValueError(
+                f"Invalid H5 file, expected Dataset, received {type(dataset)}"
+            )
+        if dataset.external:
+            raise ValueError(
+                "Not allowed: H5 file Dataset with external links: "
+                f"{dataset.external}"
+            )
+        return dataset
+
     def _get_h5_file(self, path_or_io, mode=None):
         mode = mode or self.mode
         if mode not in ("r", "w", "a"):
@@ -1094,15 +1115,19 @@ def get(self, path):
         self._h5_entry_group = {}  # Defaults to an empty dict if not found.
         if not path:
             if "vars" in self.h5_file:
-                self._h5_entry_group = self.h5_file["vars"]
+                self._h5_entry_group = self._verify_group(self.h5_file["vars"])
         elif path in self.h5_file and "vars" in self.h5_file[path]:
-            self._h5_entry_group = self.h5_file[path]["vars"]
+            self._h5_entry_group = self._verify_group(
+                self._verify_group(self.h5_file[path])["vars"]
+            )
         else:
             # No hit. Fix for 2.13 compatibility.
             if "_layer_checkpoint_dependencies" in self.h5_file:
                 path = path.replace("layers", "_layer_checkpoint_dependencies")
                 if path in self.h5_file and "vars" in self.h5_file[path]:
-                    self._h5_entry_group = self.h5_file[path]["vars"]
+                    self._h5_entry_group = self._verify_group(
+                        self._verify_group(self.h5_file[path])["vars"]
+                    )
         self._h5_entry_initialized = True
         return self
 
@@ -1134,25 +1159,15 @@ def __len__(self):
     def keys(self):
         return self._h5_entry_group.keys()
 
-    def items(self):
-        return self._h5_entry_group.items()
-
-    def values(self):
-        return self._h5_entry_group.values()
-
     def __getitem__(self, key):
-        value = self._h5_entry_group[key]
+        value = self._verify_dataset(self._h5_entry_group[key])
         if (
             hasattr(value, "attrs")
             and "dtype" in value.attrs
             and value.attrs["dtype"] == "bfloat16"
         ):
             value = np.array(value, dtype=ml_dtypes.bfloat16)
-        elif (
-            hasattr(value, "shape")
-            and hasattr(value, "dtype")
-            and not isinstance(value, np.ndarray)
-        ):
+        elif not isinstance(value, np.ndarray):
             value = np.array(value)
         return value
 
@@ -1355,25 +1370,25 @@ def _switch_h5_file(self, filename, mode):
             self._get_h5_group(self._h5_entry_path)
 
     def _restore_h5_file(self):
-        """Ensure the current shard is the last one created.
-
-        We use mode="a" to avoid truncating the file during the switching.
-        """
+        """Ensure the current shard is the last one created."""
         if (
             pathlib.Path(self.h5_file.filename).name
             != self.current_shard_path.name
         ):
-            self._switch_h5_file(self.current_shard_path.name, mode="a")
+            mode = "a" if self.mode == "w" else "r"
+            self._switch_h5_file(self.current_shard_path.name, mode=mode)
 
     # H5 entry level methods.
 
     def _get_h5_group(self, path):
         """Get the H5 entry group. If it doesn't exist, return an empty dict."""
         try:
             if not path:
-                self._h5_entry_group = self.h5_file["vars"]
+                self._h5_entry_group = self._verify_group(self.h5_file["vars"])
             else:
-                self._h5_entry_group = self.h5_file[path]["vars"]
+                self._h5_entry_group = self._verify_group(
+                    self._verify_group(self.h5_file[path])["vars"]
+                )
             self._h5_entry_initialized = True
         except KeyError:
             self._h5_entry_group = {}
@@ -1392,33 +1407,17 @@ def __len__(self):
         return total_len
 
     def keys(self):
-        keys = set(self._h5_entry_group.keys())
+        keys = []
+        current_shard_keys = list(self._h5_entry_group.keys())
         for filename in self.current_shard_filenames:
             if filename == self.current_shard_path.name:
-                continue
-            self._switch_h5_file(filename, mode="r")
-            keys.update(self._h5_entry_group.keys())
+                keys += current_shard_keys
+            else:
+                self._switch_h5_file(filename, mode="r")
+                keys += list(self._h5_entry_group.keys())
         self._restore_h5_file()
         return keys
 
-    def items(self):
-        yield from self._h5_entry_group.items()
-        for filename in self.current_shard_filenames:
-            if filename == self.current_shard_path.name:
-                continue
-            self._switch_h5_file(filename, mode="r")
-            yield from self._h5_entry_group.items()
-        self._restore_h5_file()
-
-    def values(self):
-        yield from self._h5_entry_group.values()
-        for filename in self.current_shard_filenames:
-            if filename == self.current_shard_path.name:
-                continue
-            self._switch_h5_file(filename, mode="r")
-            yield from self._h5_entry_group.values()
-        self._restore_h5_file()
-
     def __getitem__(self, key):
         if key in self._h5_entry_group:
             return super().__getitem__(key)