|
3 | 3 | import builtins |
4 | 4 | import inspect |
5 | 5 | import keras |
| 6 | +import functools |
6 | 7 | import numpy as np |
7 | 8 | import sys |
8 | 9 | from warnings import warn |
9 | 10 |
|
10 | 11 | # this import needs to be exactly like this to work with monkey patching |
11 | | -from keras.saving import deserialize_keras_object |
| 12 | +from keras.saving import deserialize_keras_object, get_registered_object, get_registered_name |
| 13 | +from keras.src.saving.serialization_lib import SerializableDict |
| 14 | +from keras import dtype_policies |
| 15 | +from keras import tree |
12 | 16 |
|
13 | 17 | from .context_managers import monkey_patch |
14 | 18 | from .decorators import allow_args |
@@ -95,6 +99,10 @@ def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs): |
95 | 99 | return obj |
96 | 100 |
|
97 | 101 |
|
| 102 | +def _deserializing_from_config(cls, config, custom_objects=None): |
| 103 | + return cls(**deserialize(config, custom_objects=custom_objects)) |
| 104 | + |
| 105 | + |
98 | 106 | @allow_args |
99 | 107 | def serializable(cls, package: str, name: str | None = None, disable_module_check: bool = False): |
100 | 108 | """Register class as Keras serializable. |
@@ -143,6 +151,68 @@ def serializable(cls, package: str, name: str | None = None, disable_module_chec |
143 | 151 | if name is None: |
144 | 152 | name = copy(cls.__name__) |
145 | 153 |
|
| 154 | + def init_decorator(original_init): |
| 155 | + # Adds auto-config behavior after the __init__ function. This extends the auto-config capabilities provided |
| 156 | + # by keras.Operation (base class of keras.Layer) with support for all serializable objects. |
| 157 | + # This produces a serialized config that has to be deserialized properly, see below. |
| 158 | + @functools.wraps(original_init) |
| 159 | + def wrapper(instance, *args, **kwargs): |
| 160 | + original_init(instance, *args, **kwargs) |
| 161 | + |
| 162 | + # Generate a config to be returned by default by `get_config()`. |
| 163 | + # Adapted from keras.Operation. |
| 164 | + kwargs = kwargs.copy() |
| 165 | + arg_names = inspect.getfullargspec(original_init).args |
| 166 | + kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) |
| 167 | + |
| 168 | + # Explicitly serialize `dtype` to support auto_config |
| 169 | + dtype = kwargs.get("dtype", None) |
| 170 | + if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy): |
| 171 | + # For backward compatibility, we use a str (`name`) for |
| 172 | + # `DTypePolicy` |
| 173 | + if dtype.quantization_mode is None: |
| 174 | + kwargs["dtype"] = dtype.name |
| 175 | + # Otherwise, use `dtype_policies.serialize` |
| 176 | + else: |
| 177 | + kwargs["dtype"] = dtype_policies.serialize(dtype) |
| 178 | + |
| 179 | + # supported basic types |
| 180 | + supported_types = (str, int, float, bool, type(None)) |
| 181 | + |
| 182 | + flat_arg_values = tree.flatten(kwargs) |
| 183 | + auto_config = True |
| 184 | + for value in flat_arg_values: |
| 185 | + # adaptation: we allow all registered serializable objects |
| 186 | + is_serializable_object = ( |
| 187 | + isinstance(value, supported_types) |
| 188 | + or get_registered_object(get_registered_name(type(value))) is not None |
| 189 | + ) |
| 190 | + # adaptation: we allow all registered serializable objects |
| 191 | + try: |
| 192 | + is_serializable_class = inspect.isclass(value) and deserialize(serialize(value)) |
| 193 | + except ValueError: |
| 194 | + # deserializtion of type failed, probably not registered |
| 195 | + is_serializable_class = False |
| 196 | + if not (is_serializable_object or is_serializable_class): |
| 197 | + auto_config = False |
| 198 | + break |
| 199 | + |
| 200 | + if auto_config: |
| 201 | + with monkey_patch(keras.saving.serialize_keras_object, serialize): |
| 202 | + instance._auto_config = SerializableDict(**kwargs) |
| 203 | + else: |
| 204 | + instance._auto_config = None |
| 205 | + |
| 206 | + return wrapper |
| 207 | + |
| 208 | + cls.__init__ = init_decorator(cls.__init__) |
| 209 | + |
| 210 | + if hasattr(cls, "from_config") and cls.from_config.__func__ == keras.Layer.from_config.__func__: |
| 211 | + # By default, keras.Layer.from_config does not deserializte the config. For this class, there is a |
| 212 | + # from_config method that is identical to keras.Layer.config, so we replace it with a variant that applies |
| 213 | + # deserialization to the config. |
| 214 | + cls.from_config = classmethod(_deserializing_from_config) |
| 215 | + |
146 | 216 | # register subclasses as keras serializable |
147 | 217 | return keras.saving.register_keras_serializable(package=package, name=name)(cls) |
148 | 218 |
|
|
0 commit comments