Skip to content

Commit 1b4ba12

Browse files
committed
move auto-config from superclass to monkey-patched decorator
1 parent 51dff0d commit 1b4ba12

File tree

6 files changed

+79
-218
lines changed

6 files changed

+79
-218
lines changed

bayesflow/networks/base_layer/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

bayesflow/networks/base_layer/base_layer.py

Lines changed: 0 additions & 210 deletions
This file was deleted.

bayesflow/networks/fusion_network/fusion_network.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Mapping
22
from ..summary_network import SummaryNetwork
3-
from bayesflow.utils.serialization import serializable, serialize
3+
from bayesflow.utils.serialization import deserialize, serializable, serialize
44
from bayesflow.types import Tensor, Shape
55
import keras
66
from keras import ops
@@ -116,3 +116,8 @@ def get_config(self) -> dict:
116116
"head": self.head,
117117
}
118118
return base_config | serialize(config)
119+
120+
@classmethod
121+
def from_config(cls, config: dict, custom_objects=None):
122+
config = deserialize(config, custom_objects=custom_objects)
123+
return cls(**config)

bayesflow/networks/inference_network.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
from bayesflow.utils import layer_kwargs, find_distribution
66
from bayesflow.utils.decorators import allow_batch_size
77
from bayesflow.utils.serialization import serializable
8-
from .base_layer import BaseLayer
98

109

1110
@serializable("bayesflow.networks")
12-
class InferenceNetwork(BaseLayer):
11+
class InferenceNetwork(keras.Layer):
1312
def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] = None, **kwargs):
1413
self.custom_metrics = metrics
1514
super().__init__(**layer_kwargs(kwargs))

bayesflow/networks/summary_network.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,17 @@
66
from bayesflow.utils import layer_kwargs, find_distribution
77
from bayesflow.utils.decorators import sanitize_input_shape
88
from bayesflow.utils.serialization import serializable
9-
from .base_layer import BaseLayer
109

1110

1211
@serializable("bayesflow.networks")
13-
class SummaryNetwork(BaseLayer):
12+
class SummaryNetwork(keras.Layer):
1413
def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] = None, **kwargs):
1514
self.custom_metrics = metrics
1615
super().__init__(**layer_kwargs(kwargs))
1716
self.base_distribution = find_distribution(base_distribution)
1817

1918
@sanitize_input_shape
2019
def build(self, input_shape):
21-
print("SN build", self, input_shape)
2220
x = keras.ops.zeros(input_shape)
2321
z = self.call(x)
2422

bayesflow/utils/serialization.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@
33
import builtins
44
import inspect
55
import keras
6+
import functools
67
import numpy as np
78
import sys
89
from warnings import warn
910

1011
# this import needs to be exactly like this to work with monkey patching
11-
from keras.saving import deserialize_keras_object
12+
from keras.saving import deserialize_keras_object, get_registered_object, get_registered_name
13+
from keras.src.saving.serialization_lib import SerializableDict
14+
from keras import dtype_policies
15+
from keras import tree
1216

1317
from .context_managers import monkey_patch
1418
from .decorators import allow_args
@@ -95,6 +99,10 @@ def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs):
9599
return obj
96100

97101

102+
def _deserializing_from_config(cls, config, custom_objects=None):
103+
return cls(**deserialize(config, custom_objects=custom_objects))
104+
105+
98106
@allow_args
99107
def serializable(cls, package: str, name: str | None = None, disable_module_check: bool = False):
100108
"""Register class as Keras serializable.
@@ -143,6 +151,68 @@ def serializable(cls, package: str, name: str | None = None, disable_module_chec
143151
if name is None:
144152
name = copy(cls.__name__)
145153

154+
def init_decorator(original_init):
155+
# Adds auto-config behavior after the __init__ function. This extends the auto-config capabilities provided
156+
# by keras.Operation (base class of keras.Layer) with support for all serializable objects.
157+
# This produces a serialized config that has to be deserialized properly, see below.
158+
@functools.wraps(original_init)
159+
def wrapper(instance, *args, **kwargs):
160+
original_init(instance, *args, **kwargs)
161+
162+
# Generate a config to be returned by default by `get_config()`.
163+
# Adapted from keras.Operation.
164+
kwargs = kwargs.copy()
165+
arg_names = inspect.getfullargspec(original_init).args
166+
kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
167+
168+
# Explicitly serialize `dtype` to support auto_config
169+
dtype = kwargs.get("dtype", None)
170+
if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy):
171+
# For backward compatibility, we use a str (`name`) for
172+
# `DTypePolicy`
173+
if dtype.quantization_mode is None:
174+
kwargs["dtype"] = dtype.name
175+
# Otherwise, use `dtype_policies.serialize`
176+
else:
177+
kwargs["dtype"] = dtype_policies.serialize(dtype)
178+
179+
# supported basic types
180+
supported_types = (str, int, float, bool, type(None))
181+
182+
flat_arg_values = tree.flatten(kwargs)
183+
auto_config = True
184+
for value in flat_arg_values:
185+
# adaptation: we allow all registered serializable objects
186+
is_serializable_object = (
187+
isinstance(value, supported_types)
188+
or get_registered_object(get_registered_name(type(value))) is not None
189+
)
190+
# adaptation: we allow all registered serializable objects
191+
try:
192+
is_serializable_class = inspect.isclass(value) and deserialize(serialize(value))
193+
except ValueError:
194+
# deserializtion of type failed, probably not registered
195+
is_serializable_class = False
196+
if not (is_serializable_object or is_serializable_class):
197+
auto_config = False
198+
break
199+
200+
if auto_config:
201+
with monkey_patch(keras.saving.serialize_keras_object, serialize):
202+
instance._auto_config = SerializableDict(**kwargs)
203+
else:
204+
instance._auto_config = None
205+
206+
return wrapper
207+
208+
cls.__init__ = init_decorator(cls.__init__)
209+
210+
if hasattr(cls, "from_config") and cls.from_config.__func__ == keras.Layer.from_config.__func__:
211+
# By default, keras.Layer.from_config does not deserializte the config. For this class, there is a
212+
# from_config method that is identical to keras.Layer.config, so we replace it with a variant that applies
213+
# deserialization to the config.
214+
cls.from_config = classmethod(_deserializing_from_config)
215+
146216
# register subclasses as keras serializable
147217
return keras.saving.register_keras_serializable(package=package, name=name)(cls)
148218

0 commit comments

Comments
 (0)