Skip to content

Commit 693764a

Browse files
authored
Make vectorized_map op serializable. (#21597)
All ops that take tensors as inputs are in scope for serialization. They get serialized if they appear as part of a functional model. The `VectorizedMap` op class needs a custom `get_config()` and `from_config()` to handle the `function` parameter, which is not a simple type. The fallback `get_config` logic only handles simple types. Note that some other ops will need a similar change, which will come in a separate PR because it requires a another change.
1 parent 0887984 commit 693764a

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ CoreOpsCorrectnessTest::test_scatter
186186
CoreOpsCorrectnessTest::test_switch
187187
CoreOpsCorrectnessTest::test_unstack
188188
CoreOpsCorrectnessTest::test_vectorized_map
189+
CoreOpsBehaviorTests::test_vectorized_map_serialization
189190
ExtractSequencesOpTest::test_extract_sequences_call
190191
InTopKTest::test_in_top_k_call
191192
MathOpsCorrectnessTest::test_erfinv_operation_basic

keras/src/ops/core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from keras.src.backend import any_symbolic_tensors
99
from keras.src.backend.common.backend_utils import slice_along_axis
1010
from keras.src.ops.operation import Operation
11+
from keras.src.saving import serialization_lib
1112
from keras.src.utils import traceback_utils
1213

1314

@@ -1105,6 +1106,19 @@ def append_batch_axis(t):
11051106
y = tree.map_structure(append_batch_axis, y)
11061107
return y
11071108

1109+
def get_config(self):
1110+
config = super().get_config()
1111+
config.update({"function": self.function})
1112+
return config
1113+
1114+
@classmethod
1115+
def from_config(cls, config):
1116+
config = config.copy()
1117+
config["function"] = serialization_lib.deserialize_keras_object(
1118+
config["function"]
1119+
)
1120+
return cls(**config)
1121+
11081122

11091123
@keras_export("keras.ops.vectorized_map")
11101124
def vectorized_map(function, elements):

keras/src/ops/core_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from keras.src import tree
1717
from keras.src.backend.common import dtypes
1818
from keras.src.backend.common.keras_tensor import KerasTensor
19+
from keras.src.layers.core import input_layer
1920
from keras.src.ops import core
21+
from keras.src.saving import object_registration
2022
from keras.src.testing.test_utils import named_product
2123

2224

@@ -1622,6 +1624,18 @@ def test_stop_gradient_compute_output_spec(self):
16221624
self.assertEqual(output_spec.shape, variable.shape)
16231625
self.assertEqual(output_spec.dtype, variable.dtype)
16241626

1627+
def test_vectorized_map_serialization(self):
1628+
@object_registration.register_keras_serializable()
1629+
def f(x):
1630+
return x + x
1631+
1632+
inputs = input_layer.Input((10,), dtype="float32")
1633+
outputs = core.vectorized_map(f, inputs)
1634+
model = models.Functional(inputs, outputs)
1635+
reloaded_model = model.from_config(model.get_config())
1636+
x = np.random.rand(5, 10).astype("float32")
1637+
self.assertAllClose(model(x), reloaded_model(x))
1638+
16251639
def test_while_loop_output_spec(self):
16261640
# Define dummy cond and body functions
16271641
def cond(x):

0 commit comments

Comments
 (0)