Skip to content

Commit e704b46

Browse files
fix torch module wrapper serialization error (#21505)
* fix torch module wrapper serialization error
* make fix narrower
* address review comments
* fix gpu tests
* fix error
1 parent a8c245f commit e704b46

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

keras/src/utils/torch_utils.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import io
23

34
from packaging.version import parse
@@ -152,8 +153,10 @@ def get_config(self):
152153

153154
buffer = io.BytesIO()
154155
torch.save(self.module, buffer)
156+
# Encode the buffer using base64 to ensure safe serialization
157+
buffer_b64 = base64.b64encode(buffer.getvalue()).decode("ascii")
155158
config = {
156-
"module": buffer.getvalue(),
159+
"module": buffer_b64,
157160
"output_shape": self.output_shape,
158161
}
159162
return {**base_config, **config}
@@ -163,7 +166,9 @@ def from_config(cls, config):
163166
import torch
164167

165168
if "module" in config:
166-
buffer = io.BytesIO(config["module"])
169+
# Decode the base64 string back to bytes
170+
buffer_bytes = base64.b64decode(config["module"].encode("ascii"))
171+
buffer = io.BytesIO(buffer_bytes)
167172
config["module"] = torch.load(buffer, weights_only=False)
168173
return cls(**config)
169174

keras/src/utils/torch_utils_test.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from keras.src import models
1212
from keras.src import saving
1313
from keras.src import testing
14+
from keras.src.backend.torch.core import get_device
1415
from keras.src.utils.torch_utils import TorchModuleWrapper
1516

1617

@@ -246,3 +247,27 @@ def test_build_model(self):
246247
model = keras.Model(x, y)
247248
self.assertEqual(model.predict(np.zeros([5, 4])).shape, (5, 16))
248249
self.assertEqual(model(np.zeros([5, 4])).shape, (5, 16))
250+
251+
def test_save_load(self):
    """Round-trip a model that holds a torch module through a `.keras` file.

    Builds a small `keras.Model` subclass whose only sublayer is a
    `torch.nn.Sequential`, calls it once, saves it to a temporary
    `mymodel.keras` path, reloads it with `saving.load_model`, and checks
    that the reloaded weights match the originals.
    """

    @keras.saving.register_keras_serializable()
    class M(keras.Model):
        def __init__(self, channels=10, **kwargs):
            # Forward the extra keyword arguments to the base class so
            # they are not silently discarded when the model is rebuilt
            # from its saved config.
            super().__init__(**kwargs)
            self.sequence = torch.nn.Sequential(
                torch.nn.Conv2d(1, channels, kernel_size=(3, 3)),
            )

        def call(self, x):
            return self.sequence(x)

    m = M()
    # Get the current device (e.g., "cuda" or "cpu") and place the input
    # on it so the forward pass does not mix devices.
    device = get_device()
    x = torch.ones((10, 1, 28, 28), device=device)
    m(x)

    temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras")
    m.save(temp_filepath)
    new_model = saving.load_model(temp_filepath)

    ref_weights = m.get_weights()
    new_weights = new_model.get_weights()
    # `zip` stops at the shorter sequence, so compare the counts first;
    # otherwise a reloaded model with missing weights would pass vacuously.
    self.assertEqual(len(ref_weights), len(new_weights))
    for ref_w, new_w in zip(ref_weights, new_weights):
        self.assertAllClose(ref_w, new_w, atol=1e-5)

0 commit comments

Comments
 (0)