Skip to content

Commit ce0d278

Browse files
authored
Disable torch.load in TorchModuleWrapper when in safe mode. (#21575)
Raise an exception and explain the user about the risks.
1 parent cd7ec31 commit ce0d278

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

keras/src/utils/torch_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from keras.src.layers import Layer
99
from keras.src.ops import convert_to_numpy
1010
from keras.src.ops import convert_to_tensor
11+
from keras.src.saving.serialization_lib import in_safe_mode
1112

1213

1314
@keras_export("keras.layers.TorchModuleWrapper")
@@ -166,6 +167,17 @@ def from_config(cls, config):
166167
import torch
167168

168169
if "module" in config:
170+
if in_safe_mode():
171+
raise ValueError(
172+
"Requested the deserialization of a `torch.nn.Module` "
173+
"object via `torch.load()`. This carries a potential risk "
174+
"of arbitrary code execution and thus it is disallowed by "
175+
"default. If you trust the source of the saved model, you "
176+
"can pass `safe_mode=False` to the loading function in "
177+
"order to allow `torch.nn.Module` loading, or call "
178+
"`keras.config.enable_unsafe_deserialization()`."
179+
)
180+
169181
# Decode the base64 string back to bytes
170182
buffer_bytes = base64.b64decode(config["module"].encode("ascii"))
171183
buffer = io.BytesIO(buffer_bytes)

keras/src/utils/torch_utils_test.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -248,26 +248,44 @@ def test_build_model(self):
248248
self.assertEqual(model.predict(np.zeros([5, 4])).shape, (5, 16))
249249
self.assertEqual(model(np.zeros([5, 4])).shape, (5, 16))
250250

251-
def test_save_load(self):
251+
@parameterized.named_parameters(
252+
("safe_mode", True),
253+
("unsafe_mode", False),
254+
)
255+
def test_save_load(self, safe_mode):
252256
@keras.saving.register_keras_serializable()
253257
class M(keras.Model):
254-
def __init__(self, channels=10, **kwargs):
255-
super().__init__()
256-
self.sequence = torch.nn.Sequential(
257-
torch.nn.Conv2d(1, channels, kernel_size=(3, 3)),
258-
)
258+
def __init__(self, module, **kwargs):
259+
super().__init__(**kwargs)
260+
self.module = module
259261

260262
def call(self, x):
261-
return self.sequence(x)
263+
return self.module(x)
262264

263-
m = M()
265+
def get_config(self):
266+
base_config = super().get_config()
267+
config = {"module": self.module}
268+
return {**base_config, **config}
269+
270+
@classmethod
271+
def from_config(cls, config):
272+
config["module"] = saving.deserialize_keras_object(
273+
config["module"]
274+
)
275+
return cls(**config)
276+
277+
m = M(torch.nn.Conv2d(1, 10, kernel_size=(3, 3)))
264278
device = get_device() # Get the current device (e.g., "cuda" or "cpu")
265279
x = torch.ones(
266280
(10, 1, 28, 28), device=device
267281
) # Place input on the correct device
268-
m(x)
282+
ref_output = m(x)
269283
temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras")
270284
m.save(temp_filepath)
271-
new_model = saving.load_model(temp_filepath)
272-
for ref_w, new_w in zip(m.get_weights(), new_model.get_weights()):
273-
self.assertAllClose(ref_w, new_w, atol=1e-5)
285+
286+
if safe_mode:
287+
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
288+
saving.load_model(temp_filepath, safe_mode=safe_mode)
289+
else:
290+
new_model = saving.load_model(temp_filepath, safe_mode=safe_mode)
291+
self.assertAllClose(new_model(x), ref_output)

0 commit comments

Comments
 (0)