You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add consistent support for name serialization of Operations. (#21373)
This is to solve an inconsistency in the saving / reloading of ops. Some ops behave differently.
Consider this code:
```python
input = keras.layers.Input(shape=(4,), dtype="float32")
output = keras.ops.abs(input)
model = keras.models.Model(input, output)
json = model.to_json()
reloaded_model = keras.models.model_from_json(json)
reloaded_json = reloaded_model.to_json()
```
The reloaded model JSON is the same as the original model JSON. The `abs` op is serialized as
- `{"module": "keras.src.ops.numpy", "class_name": "Absolute", "config": {"name": "absolute"}, "registered_name": "Absolute", "name": "absolute", ...}`
Consider the same code with `abs` replaced with `sum`:
```python
input = keras.layers.Input(shape=(4,), dtype="float32")
output = keras.ops.abs(input)
model = keras.models.Model(input, output)
json = model.to_json()
reloaded_model = keras.models.model_from_json(json)
reloaded_json = reloaded_model.to_json()
```
The reloaded model JSON is different from the original JSON.
- `{"module": "keras.src.ops.numpy", "class_name": "Sum", "config": {"axis": null, "keepdims": false}, "registered_name": "Sum", "name": "sum", ...}`
- `{"module": "keras.src.ops.numpy", "class_name": "Sum", "config": {"axis": null, "keepdims": false}, "registered_name": "Sum", "name": "sum_1", ...}`
The reloaded `sum` op now has a name `"sum_1"` instead of `"sum"`.
This is because:
- `Abs` does not define `__init__` and inherits the `Operation.__init__` that has a `name` parameter.
- `Sum` defines `__init__` without a `name` parameter.
We want saving / reloading to be idempotent. Even though users cannot control the name of ops (although it would be easy to add), the auto-assigned names should be saved and reloaded. For this, `name` has to be supported in `__init__`.
This PR changes the behavior of serialization from `_auto_config` to allow `name` to be returned when the `__init__` signature has `name` or `**kwargs`. The existing logic had a bug. Also switched to `inspect.signature` which is the recommended API in Python 3.
Note that adding `name` as a parameter to all ops `__init__`s will be done as a separate PR.
0 commit comments