Skip to content

Commit 662614e

Browse files
committed
deprecate (de)serialize_value_or_type
- replaced by serialize and deserialize - adapt classes in the point inference network module which relied on the old structure - add more lenient tests to compare configs - adapt developer docs
1 parent a1b4d19 commit 662614e

File tree

10 files changed

+70
-125
lines changed

10 files changed

+70
-125
lines changed

bayesflow/links/ordered.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import keras
2-
from keras.saving import register_keras_serializable as serializable
32

43
from bayesflow.utils import layer_kwargs
54
from bayesflow.utils.decorators import sanitize_input_shape
5+
from bayesflow.utils.serialization import deserialize, serializable, serialize
66

77

88
@serializable(package="links.ordered")
@@ -14,12 +14,20 @@ def __init__(self, axis: int, anchor_index: int, **kwargs):
1414
self.axis = axis
1515
self.anchor_index = anchor_index
1616
self.group_indices = None
17-
18-
self.config = {"axis": axis, "anchor_index": anchor_index, **kwargs}
17+
self._kwargs = kwargs
1918

2019
def get_config(self):
2120
base_config = super().get_config()
22-
return base_config | self.config
21+
config = {
22+
"axis": self.axis,
23+
"anchor_index": self.anchor_index,
24+
**self._kwargs,
25+
}
26+
return base_config | serialize(config)
27+
28+
@classmethod
29+
def from_config(cls, config, custom_objects=None):
30+
return cls(**deserialize(config, custom_objects=custom_objects))
2331

2432
def build(self, input_shape):
2533
super().build(input_shape)

bayesflow/links/ordered_quantiles.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import keras
2-
from keras.saving import register_keras_serializable as serializable
32

43
from bayesflow.utils import layer_kwargs, logging
4+
from bayesflow.utils.serialization import serializable
55

66
from collections.abc import Sequence
77

@@ -16,14 +16,12 @@ def __init__(self, q: Sequence[float] = None, axis: int = None, **kwargs):
1616
super().__init__(axis, None, **layer_kwargs(kwargs))
1717
self.q = q
1818

19-
self.config = {
20-
"q": q,
21-
"axis": axis,
22-
}
23-
2419
def get_config(self):
2520
base_config = super().get_config()
26-
return base_config | self.config
21+
config = {
22+
"q": self.q,
23+
}
24+
return base_config | config
2725

2826
def build(self, input_shape):
2927
if self.axis is None and 1 < len(input_shape) <= 3:

bayesflow/networks/point_inference_network.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
import keras
2-
from keras.saving import (
3-
deserialize_keras_object as deserialize,
4-
serialize_keras_object as serialize,
5-
register_keras_serializable as serializable,
6-
)
72

8-
from bayesflow.utils import model_kwargs, find_network, serialize_value_or_type, deserialize_value_or_type
3+
from bayesflow.utils import model_kwargs, find_network
4+
from bayesflow.utils.serialization import deserialize, serializable, serialize
95
from bayesflow.types import Shape, Tensor
106
from bayesflow.scores import ScoringRule, ParametricDistributionScore
117
from bayesflow.utils.decorators import allow_batch_size
@@ -26,14 +22,9 @@ def __init__(
2622
super().__init__(**model_kwargs(kwargs))
2723

2824
self.scores = scores
29-
3025
self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
3126

32-
self.config = {
33-
**kwargs,
34-
}
35-
self.config = serialize_value_or_type(self.config, "subnet", subnet)
36-
self.config["scores"] = serialize(self.scores)
27+
self._kwargs = kwargs
3728

3829
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
3930
"""Builds all network components based on shapes of conditions and targets.
@@ -112,15 +103,18 @@ def build_from_config(self, config):
112103

113104
def get_config(self):
114105
base_config = super().get_config()
106+
config = {
107+
"scores": self.scores,
108+
"subnet": self.subnet,
109+
**self._kwargs,
110+
}
115111

116-
return base_config | self.config
112+
return base_config | serialize(config)
117113

118114
@classmethod
119115
def from_config(cls, config):
120116
config = config.copy()
121-
config["scores"] = deserialize(config["scores"])
122-
config = deserialize_value_or_type(config, "subnet")
123-
return cls(**config)
117+
return cls(**deserialize(config))
124118

125119
def call(
126120
self,

bayesflow/scores/quantile_score.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from typing import Sequence
22

33
import keras
4-
from keras.saving import register_keras_serializable as serializable
54

65
from bayesflow.types import Shape, Tensor
76
from bayesflow.utils import logging, weighted_mean
87
from bayesflow.links import OrderedQuantiles
8+
from bayesflow.utils.serialization import deserialize, serializable, serialize
99

1010
from .scoring_rule import ScoringRule
1111

@@ -31,13 +31,17 @@ def __init__(self, q: Sequence[float] = None, links=None, **kwargs):
3131
self._q = keras.ops.convert_to_tensor(q, dtype="float32")
3232
self.links = links or {"value": OrderedQuantiles(q=q)}
3333

34-
self.config = {
35-
"q": q,
36-
}
37-
3834
def get_config(self):
3935
base_config = super().get_config()
40-
return base_config | self.config
36+
config = {
37+
"q": self.q,
38+
"links": self.links,
39+
}
40+
return base_config | serialize(config)
41+
42+
@classmethod
43+
def from_config(cls, config, custom_objects=None):
44+
return cls(**deserialize(config, custom_objects=custom_objects))
4145

4246
def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, tuple]:
4347
# keras.saving.load_model sometimes passes target_shape as a list, so we force a conversion

bayesflow/scores/scoring_rule.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import math
22

33
import keras
4-
from keras.saving import register_keras_serializable as serializable
54

65
from bayesflow.types import Shape, Tensor
7-
from bayesflow.utils import find_network, serialize_value_or_type, deserialize_value_or_type
6+
from bayesflow.utils import find_network
7+
from bayesflow.utils.serialization import deserialize, serializable, serialize
88

99

1010
@serializable(package="bayesflow.scores")
@@ -48,28 +48,17 @@ def __init__(
4848
self.subnets_kwargs = subnets_kwargs or {}
4949
self.links = links or {}
5050

51-
self.config = {"subnets_kwargs": self.subnets_kwargs}
52-
5351
def get_config(self):
54-
self.config["subnets"] = {
55-
key: serialize_value_or_type({}, "subnet", subnet) for key, subnet in self.subnets.items()
52+
config = {
53+
"subnets": self.subnets,
54+
"links": self.links,
5655
}
57-
self.config["links"] = {key: serialize_value_or_type({}, "link", link) for key, link in self.links.items()}
5856

59-
return self.config
57+
return serialize(config)
6058

6159
@classmethod
6260
def from_config(cls, config):
63-
config = config.copy()
64-
config["subnets"] = {
65-
key: deserialize_value_or_type(subnet_dict, "subnet")["subnet"]
66-
for key, subnet_dict in config["subnets"].items()
67-
}
68-
config["links"] = {
69-
key: deserialize_value_or_type(link_dict, "link")["link"] for key, link_dict in config["links"].items()
70-
}
71-
72-
return cls(**config)
61+
return cls(**deserialize(config))
7362

7463
def get_head_shapes_from_target_shape(self, target_shape: Shape) -> dict[str, Shape]:
7564
"""Request a dictionary of names and output shapes of required heads from the score."""

bayesflow/utils/serialization.py

Lines changed: 13 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import keras
66
import numpy as np
77
import sys
8+
from warnings import warn
89

910
# this import needs to be exactly like this to work with monkey patching
1011
from keras.saving import deserialize_keras_object
@@ -19,77 +20,21 @@
1920

2021

2122
def serialize_value_or_type(config, name, obj):
22-
"""Serialize an object that can be either a value or a type
23-
and add it to a copy of the supplied dictionary.
24-
25-
Parameters
26-
----------
27-
config : dict
28-
Dictionary to add the serialized object to. This function does not
29-
modify the dictionary in place, but returns a modified copy.
30-
name : str
31-
Name of the obj that should be stored. Required for later deserialization.
32-
obj : object or type
33-
The object to serialize. If `obj` is of type `type`, we use
34-
`keras.saving.get_registered_name` to obtain the registered type name.
35-
If it is not a type, we try to serialize it as a Keras object.
36-
37-
Returns
38-
-------
39-
updated_config : dict
40-
Updated dictionary with a new key `"_bayesflow_<name>_type"` or
41-
`"_bayesflow_<name>_val"`. The prefix is used to avoid name collisions,
42-
the suffix indicates how the stored value has to be deserialized.
43-
44-
Notes
45-
-----
46-
We allow strings or `type` parameters at several places to instantiate objects
47-
of a given type (e.g., `subnet` in `CouplingFlow`). As `type` objects cannot
48-
be serialized, we have to distinguish the two cases for serialization and
49-
deserialization. This function is a helper function to standardize and
50-
simplify this.
51-
"""
52-
updated_config = config.copy()
53-
if isinstance(obj, type):
54-
updated_config[f"{PREFIX}{name}_type"] = keras.saving.get_registered_name(obj)
55-
else:
56-
updated_config[f"{PREFIX}{name}_val"] = keras.saving.serialize_keras_object(obj)
57-
return updated_config
23+
"""This function is deprecated."""
24+
warn(
25+
"This method is deprecated. It was replaced by bayesflow.utils.serialization.serialize.",
26+
DeprecationWarning,
27+
stacklevel=2,
28+
)
5829

5930

6031
def deserialize_value_or_type(config, name):
61-
"""Deserialize an object that can be either a value or a type and add
62-
it to the supplied dictionary.
63-
64-
Parameters
65-
----------
66-
config : dict
67-
Dictionary containing the object to deserialize. If a type was
68-
serialized, it should contain the key `"_bayesflow_<name>_type"`.
69-
If an object was serialized, it should contain the key
70-
`"_bayesflow_<name>_val"`. In a copy of this dictionary,
71-
the item will be replaced with the key `name`.
72-
name : str
73-
Name of the object to deserialize.
74-
75-
Returns
76-
-------
77-
updated_config : dict
78-
Updated dictionary with a new key `name`, with a value that is either
79-
a type or an object.
80-
81-
See Also
82-
--------
83-
serialize_value_or_type
84-
"""
85-
updated_config = config.copy()
86-
if f"{PREFIX}{name}_type" in config:
87-
updated_config[name] = keras.saving.get_registered_object(config[f"{PREFIX}{name}_type"])
88-
del updated_config[f"{PREFIX}{name}_type"]
89-
elif f"{PREFIX}{name}_val" in config:
90-
updated_config[name] = keras.saving.deserialize_keras_object(config[f"{PREFIX}{name}_val"])
91-
del updated_config[f"{PREFIX}{name}_val"]
92-
return updated_config
32+
"""This function is deprecated."""
33+
warn(
34+
"This method is deprecated. It was replaced by bayesflow.utils.serialization.deserialize.",
35+
DeprecationWarning,
36+
stacklevel=2,
37+
)
9338

9439

9540
def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs):

docsrc/source/development/serialization.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ We mainly provide three convenience functions:
2121
- The {py:func}`~bayesflow.utils.serialization.serialize` function, which adds support for serializing classes.
2222
- Its counterpart {py:func}`~bayesflow.utils.serialization.deserialize`, adds support to deserialize classes.
2323

24-
_Note: The `(de)serialize_value_or_type` functions are made obsolete by the functions given above and will probably be deprecated soon._
25-
2624
## Usage
2725

2826
To use the adapted serialization functions, you have to use them in the `get_config` and `from_config` method. Please refer to existing classes in the library for usage examples.

tests/test_networks/test_point_inference_network/test_point_inference_network.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
deserialize_keras_object as deserialize,
44
serialize_keras_object as serialize,
55
)
6-
from tests.utils import assert_layers_equal
6+
from tests.utils import assert_layers_equal, assert_configs_equal
77
import pytest
88

99

@@ -71,8 +71,7 @@ def test_save_and_load_quantile(tmp_path, quantile_point_inference_network, rand
7171
keras.saving.save_model(net, tmp_path / "model.keras")
7272
loaded = keras.saving.load_model(tmp_path / "model.keras")
7373

74-
print(net.get_config())
75-
assert net.get_config() == loaded.get_config()
74+
assert_configs_equal(net.get_config(), loaded.get_config())
7675

7776
assert_layers_equal(net, loaded)
7877

tests/test_scores/test_scores.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def test_score_output(scoring_rule, random_conditions):
1717

1818
# Using random random_conditions also as targets for the purpose of this test.
1919
head_shapes = scoring_rule.get_head_shapes_from_target_shape(random_conditions.shape)
20-
print(scoring_rule.get_config())
2120
estimates = {}
2221
for key, output_shape in head_shapes.items():
2322
link = scoring_rule.get_link(key)

tests/utils/assertions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,14 @@ def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer):
3535
# this is turned off for now, see https://github.com/bayesflow-org/bayesflow/issues/412
3636
msg = f"Layers {layer1.name} and {layer2.name} have a different name."
3737
# assert layer1.name == layer2.name, msg
38+
39+
40+
def assert_configs_equal(config1, config2):
41+
"""Asserts that two configs are equal.
42+
43+
Ignores whether lists or tuples were used, a difference that sometimes arises for
44+
the `input_shape` in `build_config` entries
45+
"""
46+
config1 = keras.tree.lists_to_tuples(config1)
47+
config2 = keras.tree.lists_to_tuples(config2)
48+
assert config1 == config2

0 commit comments

Comments
 (0)