Skip to content
This repository was archived by the owner on Mar 10, 2026. It is now read-only.

Commit 180e134

Browse files
authored
Slice serialization (#932)
* Add serialization of slices. * Fix case when `getitem` doesn't accept list as input.
1 parent a354797 commit 180e134

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

keras_core/ops/numpy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2553,6 +2553,8 @@ def full_like(x, fill_value, dtype=None):
25532553

25542554
class GetItem(Operation):
25552555
def call(self, x, key):
2556+
if isinstance(key, list):
2557+
key = tuple(key)
25562558
return x[key]
25572559

25582560
def compute_output_spec(self, x, key):

keras_core/saving/serialization_lib.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,15 @@ def serialize_keras_object(obj):
153153
"class_name": "__bytes__",
154154
"config": {"value": obj.decode("utf-8")},
155155
}
156+
if isinstance(obj, slice):
157+
return {
158+
"class_name": "__slice__",
159+
"config": {
160+
"start": serialize_keras_object(obj.start),
161+
"stop": serialize_keras_object(obj.stop),
162+
"step": serialize_keras_object(obj.step),
163+
},
164+
}
156165
if isinstance(obj, backend.KerasTensor):
157166
history = getattr(obj, "_keras_history", None)
158167
if history:
@@ -602,6 +611,24 @@ class ModifiedMeanSquaredError(keras_core.losses.MeanSquaredError):
602611
return np.array(inner_config["value"], dtype=inner_config["dtype"])
603612
if config["class_name"] == "__bytes__":
604613
return inner_config["value"].encode("utf-8")
614+
if config["class_name"] == "__slice__":
615+
return slice(
616+
deserialize_keras_object(
617+
inner_config["start"],
618+
custom_objects=custom_objects,
619+
safe_mode=safe_mode,
620+
),
621+
deserialize_keras_object(
622+
inner_config["stop"],
623+
custom_objects=custom_objects,
624+
safe_mode=safe_mode,
625+
),
626+
deserialize_keras_object(
627+
inner_config["step"],
628+
custom_objects=custom_objects,
629+
safe_mode=safe_mode,
630+
),
631+
)
605632
if config["class_name"] == "__lambda__":
606633
if safe_mode:
607634
raise ValueError(

keras_core/saving/serialization_lib_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def test_simple_objects(self):
7575
["hello", 0, "world", 1.0, True],
7676
{"1": "hello", "2": 0, "3": True},
7777
{"1": "hello", "2": [True, False]},
78+
slice(None, 20, 1),
79+
slice(None, np.array([0, 1]), 1),
7880
]:
7981
serialized, _, reserialized = self.roundtrip(obj)
8082
self.assertEqual(serialized, reserialized)

0 commit comments

Comments
 (0)