Skip to content

Commit 11f737e

Browse files
authored
Add consistent support for name serialization of Operations. (#21373)
This is to solve an inconsistency in the saving / reloading of ops. Some ops behave differently from others.

Consider this code:

```python
input = keras.layers.Input(shape=(4,), dtype="float32")
output = keras.ops.abs(input)
model = keras.models.Model(input, output)
json = model.to_json()
reloaded_model = keras.models.model_from_json(json)
reloaded_json = reloaded_model.to_json()
```

The reloaded model JSON is the same as the original model JSON. The `abs` op is serialized as:

- `{"module": "keras.src.ops.numpy", "class_name": "Absolute", "config": {"name": "absolute"}, "registered_name": "Absolute", "name": "absolute", ...}`

Consider the same code with `abs` replaced with `sum`:

```python
input = keras.layers.Input(shape=(4,), dtype="float32")
output = keras.ops.sum(input)
model = keras.models.Model(input, output)
json = model.to_json()
reloaded_model = keras.models.model_from_json(json)
reloaded_json = reloaded_model.to_json()
```

The reloaded model JSON is different from the original JSON:

- Original: `{"module": "keras.src.ops.numpy", "class_name": "Sum", "config": {"axis": null, "keepdims": false}, "registered_name": "Sum", "name": "sum", ...}`
- Reloaded: `{"module": "keras.src.ops.numpy", "class_name": "Sum", "config": {"axis": null, "keepdims": false}, "registered_name": "Sum", "name": "sum_1", ...}`

The reloaded `sum` op now has the name `"sum_1"` instead of `"sum"`. This is because:

- `Abs` does not define `__init__` and inherits the `Operation.__init__` that has a `name` parameter.
- `Sum` defines `__init__` without a `name` parameter.

We want saving / reloading to be idempotent. Even though users cannot control the name of ops (although it would be easy to add), the auto-assigned names should be saved and reloaded. For this, `name` has to be supported in `__init__`.

This PR changes the behavior of serialization from `_auto_config` to allow `name` to be returned when the `__init__` signature has `name` or `**kwargs`. The existing logic had a bug.
Also switched to `inspect.signature`, which is the recommended API in Python 3. Note that adding `name` as a parameter to every op's `__init__` will be done in a separate PR.
1 parent 08ad93b commit 11f737e

File tree

2 files changed

+152
-17
lines changed

2 files changed

+152
-17
lines changed

keras/src/ops/operation.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,16 @@ def get_config(self):
171171
# In this case the subclass doesn't implement get_config():
172172
# Let's see if we can autogenerate it.
173173
if getattr(self, "_auto_config", None) is not None:
174-
xtra_args = set(config.keys())
175174
config.update(self._auto_config.config)
176-
# Remove args non explicitly supported
177-
argspec = inspect.getfullargspec(self.__init__)
178-
if argspec.varkw != "kwargs":
179-
for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
180-
config.pop(key, None)
175+
init_params = inspect.signature(self.__init__).parameters
176+
init_has_name = "name" in init_params
177+
init_has_kwargs = (
178+
"kwargs" in init_params
179+
and init_params["kwargs"].kind == inspect.Parameter.VAR_KEYWORD
180+
)
181+
if not init_has_name and not init_has_kwargs:
182+
# We can't pass `name` back to `__init__`, remove it.
183+
config.pop("name", None)
181184
return config
182185
else:
183186
raise NotImplementedError(

keras/src/ops/operation_test.py

Lines changed: 143 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,71 @@ def compute_output_spec(self, x):
3030

3131

3232
class OpWithCustomConstructor(operation.Operation):
33-
def __init__(self, alpha, mode="foo"):
33+
def __init__(self, alpha, *, name=None):
34+
super().__init__(name=name)
35+
self.alpha = alpha
36+
37+
def call(self, x):
38+
return self.alpha * x
39+
40+
def compute_output_spec(self, x):
41+
return keras_tensor.KerasTensor(x.shape, x.dtype)
42+
43+
44+
class OpWithCustomConstructorNoName(operation.Operation):
45+
def __init__(self, alpha):
3446
super().__init__()
3547
self.alpha = alpha
36-
self.mode = mode
3748

3849
def call(self, x):
39-
if self.mode == "foo":
40-
return x
4150
return self.alpha * x
4251

4352
def compute_output_spec(self, x):
4453
return keras_tensor.KerasTensor(x.shape, x.dtype)
4554

4655

56+
class OpWithKwargsInConstructor(operation.Operation):
57+
def __init__(self, alpha, **kwargs):
58+
super().__init__(**kwargs)
59+
self.alpha = alpha
60+
61+
def call(self, x):
62+
return self.alpha * x
63+
64+
def compute_output_spec(self, x):
65+
return keras_tensor.KerasTensor(x.shape, x.dtype)
66+
67+
68+
class OpWithCustomConstructorGetConfig(operation.Operation):
69+
def __init__(self, alpha, *, name=None):
70+
super().__init__(name=name)
71+
self.alpha = alpha
72+
73+
def call(self, x):
74+
return self.alpha * x
75+
76+
def compute_output_spec(self, x):
77+
return keras_tensor.KerasTensor(x.shape, x.dtype)
78+
79+
def get_config(self):
80+
return {**super().get_config(), "alpha": self.alpha}
81+
82+
83+
class OpWithKwargsInConstructorGetConfig(operation.Operation):
84+
def __init__(self, alpha, **kwargs):
85+
super().__init__(**kwargs)
86+
self.alpha = alpha
87+
88+
def call(self, x):
89+
return self.alpha * x
90+
91+
def compute_output_spec(self, x):
92+
return keras_tensor.KerasTensor(x.shape, x.dtype)
93+
94+
def get_config(self):
95+
return {**super().get_config(), "alpha": self.alpha}
96+
97+
4798
class OperationTest(testing.TestCase):
4899
def test_symbolic_call(self):
49100
x = keras_tensor.KerasTensor(shape=(2, 3), name="x")
@@ -129,19 +180,100 @@ def test_eager_call(self):
129180
self.assertAllClose(out[0], np.ones((2, 3)))
130181
self.assertAllClose(out[1], np.ones((2, 3)) + 1)
131182

132-
def test_serialization(self):
133-
op = OpWithMultipleOutputs(name="test_op")
183+
def test_serialization_with_default_init_and_get_config(self):
184+
# Explicit name passed in constructor is serialized and deserialized.
185+
op = OpWithMultipleInputs(name="test_op")
134186
config = op.get_config()
135187
self.assertEqual(config, {"name": "test_op"})
136-
op = OpWithMultipleOutputs.from_config(config)
137-
self.assertEqual(op.name, "test_op")
188+
revived = OpWithMultipleInputs.from_config(config)
189+
self.assertEqual(revived.get_config(), config)
190+
self.assertEqual(revived.name, op.name)
191+
192+
# Auto generated name is serialized and deserialized.
193+
op = OpWithMultipleInputs()
194+
config = op.get_config()
195+
self.assertEqual(config, {"name": op.name})
196+
revived = OpWithMultipleInputs.from_config(config)
197+
self.assertEqual(revived.get_config(), config)
198+
self.assertEqual(revived.name, op.name)
199+
200+
def test_serialization_custom_constructor_with_name_auto_config(self):
201+
# Explicit name passed in constructor is serialized and deserialized.
202+
op = OpWithCustomConstructor(alpha=0.2, name="test_op")
203+
config = op.get_config()
204+
self.assertEqual(config, {"alpha": 0.2, "name": "test_op"})
205+
revived = OpWithCustomConstructor.from_config(config)
206+
self.assertEqual(revived.get_config(), config)
207+
self.assertEqual(revived.name, op.name)
138208

139-
def test_autoconfig(self):
140-
op = OpWithCustomConstructor(alpha=0.2, mode="bar")
209+
# Auto generated name is serialized and deserialized.
210+
op = OpWithCustomConstructor(alpha=0.2)
141211
config = op.get_config()
142-
self.assertEqual(config, {"alpha": 0.2, "mode": "bar"})
212+
self.assertEqual(config, {"alpha": 0.2, "name": op.name})
143213
revived = OpWithCustomConstructor.from_config(config)
144214
self.assertEqual(revived.get_config(), config)
215+
self.assertEqual(revived.name, op.name)
216+
217+
def test_serialization_custom_constructor_with_no_name_auto_config(self):
218+
# Auto generated name is not serialized.
219+
op = OpWithCustomConstructorNoName(alpha=0.2)
220+
config = op.get_config()
221+
self.assertEqual(config, {"alpha": 0.2})
222+
revived = OpWithCustomConstructorNoName.from_config(config)
223+
self.assertEqual(revived.get_config(), config)
224+
225+
def test_serialization_custom_constructor_with_kwargs_auto_config(self):
226+
# Explicit name passed in constructor is serialized and deserialized.
227+
op = OpWithKwargsInConstructor(alpha=0.2, name="test_op")
228+
config = op.get_config()
229+
self.assertEqual(config, {"alpha": 0.2, "name": "test_op"})
230+
revived = OpWithKwargsInConstructor.from_config(config)
231+
self.assertEqual(revived.get_config(), config)
232+
self.assertEqual(revived.name, op.name)
233+
234+
# Auto generated name is serialized and deserialized.
235+
op = OpWithKwargsInConstructor(alpha=0.2)
236+
config = op.get_config()
237+
self.assertEqual(config, {"alpha": 0.2, "name": op.name})
238+
revived = OpWithKwargsInConstructor.from_config(config)
239+
self.assertEqual(revived.get_config(), config)
240+
self.assertEqual(revived.name, op.name)
241+
242+
def test_serialization_custom_constructor_custom_get_config(self):
243+
# Explicit name passed in constructor is serialized and deserialized.
244+
op = OpWithCustomConstructorGetConfig(alpha=0.2, name="test_op")
245+
config = op.get_config()
246+
self.assertEqual(config, {"alpha": 0.2, "name": "test_op"})
247+
revived = OpWithCustomConstructorGetConfig.from_config(config)
248+
self.assertEqual(revived.get_config(), config)
249+
self.assertEqual(revived.name, op.name)
250+
251+
# Auto generated name is serialized and deserialized.
252+
op = OpWithCustomConstructorGetConfig(alpha=0.2)
253+
config = op.get_config()
254+
self.assertEqual(config, {"alpha": 0.2, "name": op.name})
255+
revived = OpWithCustomConstructorGetConfig.from_config(config)
256+
self.assertEqual(revived.get_config(), config)
257+
self.assertEqual(revived.name, op.name)
258+
259+
def test_serialization_custom_constructor_with_kwargs_custom_get_config(
260+
self,
261+
):
262+
# Explicit name passed in constructor is serialized and deserialized.
263+
op = OpWithKwargsInConstructorGetConfig(alpha=0.2, name="test_op")
264+
config = op.get_config()
265+
self.assertEqual(config, {"alpha": 0.2, "name": "test_op"})
266+
revived = OpWithKwargsInConstructorGetConfig.from_config(config)
267+
self.assertEqual(revived.get_config(), config)
268+
self.assertEqual(revived.name, op.name)
269+
270+
# Auto generated name is serialized and deserialized.
271+
op = OpWithKwargsInConstructorGetConfig(alpha=0.2)
272+
config = op.get_config()
273+
self.assertEqual(config, {"alpha": 0.2, "name": op.name})
274+
revived = OpWithKwargsInConstructorGetConfig.from_config(config)
275+
self.assertEqual(revived.get_config(), config)
276+
self.assertEqual(revived.name, op.name)
145277

146278
@skip_if_backend(
147279
"openvino", "Can not constant fold eltwise node by CPU plugin"

0 commit comments

Comments
 (0)