Skip to content

Commit 147dd2d

Browse files
authored
🚀 Merge pull request #413 from bayesflow-org/allow-networks
Improve serialization and allow networks to be passed directly to most models
2 parents 8482926 + 0bf125b commit 147dd2d

File tree

93 files changed

+2392
-2148
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

93 files changed

+2392
-2148
lines changed

bayesflow/adapters/adapter.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,7 @@
22

33
import numpy as np
44

5-
from keras.saving import (
6-
deserialize_keras_object as deserialize,
7-
register_keras_serializable as serializable,
8-
serialize_keras_object as serialize,
9-
)
5+
from bayesflow.utils.serialization import deserialize, serialize, serializable
106

117
from .transforms import (
128
AsSet,
@@ -33,7 +29,7 @@
3329
from .transforms.filter_transform import Predicate
3430

3531

36-
@serializable(package="bayesflow.adapters")
32+
@serializable
3733
class Adapter(MutableSequence[Transform]):
3834
"""
3935
Defines an adapter to apply various transforms to data.
@@ -74,10 +70,14 @@ def create_default(inference_variables: Sequence[str]) -> "Adapter":
7470

7571
@classmethod
7672
def from_config(cls, config: dict, custom_objects=None) -> "Adapter":
77-
return cls(transforms=deserialize(config["transforms"], custom_objects))
73+
return cls(**deserialize(config, custom_objects=custom_objects))
7874

7975
def get_config(self) -> dict:
80-
return {"transforms": serialize(self.transforms)}
76+
config = {
77+
"transforms": self.transforms,
78+
}
79+
80+
return serialize(config)
8181

8282
def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
8383
"""Apply the transforms in the forward direction.

bayesflow/adapters/transforms/as_set.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from keras.saving import register_keras_serializable as serializable
21
import numpy as np
32

3+
from bayesflow.utils.serialization import serializable
4+
45
from .elementwise_transform import ElementwiseTransform
56

67

7-
@serializable(package="bayesflow.adapters")
8+
@serializable
89
class AsSet(ElementwiseTransform):
910
"""The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets.
1011
@@ -33,9 +34,5 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
3334

3435
return data
3536

36-
@classmethod
37-
def from_config(cls, config: dict, custom_objects=None) -> "AsSet":
38-
return cls()
39-
4037
def get_config(self) -> dict:
4138
return {}

bayesflow/adapters/transforms/as_time_series.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import numpy as np
2-
from keras.saving import register_keras_serializable as serializable
2+
3+
from bayesflow.utils.serialization import serializable
34

45
from .elementwise_transform import ElementwiseTransform
56

67

7-
@serializable(package="bayesflow.adapters")
8+
@serializable
89
class AsTimeSeries(ElementwiseTransform):
910
"""The `.as_time_series` transform can be used to indicate that variables shall be treated as time series.
1011
@@ -29,9 +30,5 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
2930

3031
return data
3132

32-
@classmethod
33-
def from_config(cls, config: dict, custom_objects=None) -> "AsTimeSeries":
34-
return cls()
35-
3633
def get_config(self) -> dict:
3734
return {}

bayesflow/adapters/transforms/broadcast.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from collections.abc import Sequence
22
import numpy as np
33

4-
from keras.saving import (
5-
deserialize_keras_object as deserialize,
6-
register_keras_serializable as serializable,
7-
serialize_keras_object as serialize,
8-
)
4+
from bayesflow.utils.serialization import serialize, serializable
95

106
from .transform import Transform
117

128

13-
@serializable(package="bayesflow.adapters")
9+
@serializable
1410
class Broadcast(Transform):
1511
"""
1612
Broadcasts arrays or scalars to the shape of a given other array.
@@ -96,31 +92,15 @@ def __init__(
9692
self.exclude = exclude
9793
self.squeeze = squeeze
9894

99-
@classmethod
100-
def from_config(cls, config: dict, custom_objects=None) -> "Broadcast":
101-
# Deserialize turns tuples to lists, undo it if necessary
102-
exclude = deserialize(config["exclude"], custom_objects)
103-
exclude = tuple(exclude) if isinstance(exclude, list) else exclude
104-
expand = deserialize(config["expand"], custom_objects)
105-
expand = tuple(expand) if isinstance(expand, list) else expand
106-
squeeze = deserialize(config["squeeze"], custom_objects)
107-
squeeze = tuple(squeeze) if isinstance(squeeze, list) else squeeze
108-
return cls(
109-
keys=deserialize(config["keys"], custom_objects),
110-
to=deserialize(config["to"], custom_objects),
111-
expand=expand,
112-
exclude=exclude,
113-
squeeze=squeeze,
114-
)
115-
11695
def get_config(self) -> dict:
117-
return {
118-
"keys": serialize(self.keys),
119-
"to": serialize(self.to),
120-
"expand": serialize(self.expand),
121-
"exclude": serialize(self.exclude),
122-
"squeeze": serialize(self.squeeze),
96+
config = {
97+
"keys": self.keys,
98+
"to": self.to,
99+
"expand": self.expand,
100+
"exclude": self.exclude,
101+
"squeeze": self.squeeze,
123102
}
103+
return serialize(config)
124104

125105
# noinspection PyMethodOverriding
126106
def forward(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:

bayesflow/adapters/transforms/concatenate.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
from collections.abc import Sequence
22

33
import numpy as np
4-
from keras.saving import (
5-
deserialize_keras_object as deserialize,
6-
register_keras_serializable as serializable,
7-
serialize_keras_object as serialize,
8-
)
4+
5+
from bayesflow.utils.serialization import serialize, serializable
96

107
from .transform import Transform
118

129

13-
@serializable(package="bayesflow.adapters")
10+
@serializable
1411
class Concatenate(Transform):
1512
"""Concatenate multiple arrays into a new key. Used to specify how data variables should be treated by the network.
1613
@@ -35,29 +32,21 @@ class Concatenate(Transform):
3532
)
3633
"""
3734

38-
def __init__(self, keys: Sequence[str], *, into: str, axis: int = -1, _indices: list | None = None):
35+
def __init__(self, keys: Sequence[str], *, into: str, axis: int = -1, indices: list | None = None):
3936
self.keys = keys
4037
self.into = into
4138
self.axis = axis
4239

43-
self.indices = _indices
44-
45-
@classmethod
46-
def from_config(cls, config: dict, custom_objects=None) -> "Concatenate":
47-
return cls(
48-
keys=deserialize(config["keys"], custom_objects),
49-
into=deserialize(config["into"], custom_objects),
50-
axis=deserialize(config["axis"], custom_objects),
51-
_indices=deserialize(config["indices"], custom_objects),
52-
)
40+
self.indices = indices
5341

5442
def get_config(self) -> dict:
55-
return {
56-
"keys": serialize(self.keys),
57-
"into": serialize(self.into),
58-
"axis": serialize(self.axis),
59-
"indices": serialize(self.indices),
43+
config = {
44+
"keys": self.keys,
45+
"into": self.into,
46+
"axis": self.axis,
47+
"indices": self.indices,
6048
}
49+
return serialize(config)
6150

6251
def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]:
6352
if not strict and self.indices is None:

bayesflow/adapters/transforms/constrain.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
from keras.saving import (
2-
register_keras_serializable as serializable,
3-
)
41
import numpy as np
52

3+
from bayesflow.utils.serialization import serializable, serialize
64
from bayesflow.utils.numpy_utils import (
75
inverse_sigmoid,
86
inverse_softplus,
@@ -13,7 +11,7 @@
1311
from .elementwise_transform import ElementwiseTransform
1412

1513

16-
@serializable(package="bayesflow.adapters")
14+
@serializable
1715
class Constrain(ElementwiseTransform):
1816
"""
1917
Constrains neural network predictions of a data variable to specified bounds.
@@ -163,18 +161,15 @@ def unconstrain(x):
163161
case other:
164162
raise ValueError(f"Unsupported value for 'inclusive': {other!r}.")
165163

166-
@classmethod
167-
def from_config(cls, config: dict, custom_objects=None) -> "Constrain":
168-
return cls(**config)
169-
170164
def get_config(self) -> dict:
171-
return {
165+
config = {
172166
"lower": self.lower,
173167
"upper": self.upper,
174168
"method": self.method,
175169
"inclusive": self.inclusive,
176170
"epsilon": self.epsilon,
177171
}
172+
return serialize(config)
178173

179174
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
180175
# forward means data space -> network space, so unconstrain the data
Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
from keras.saving import (
2-
deserialize_keras_object as deserialize,
3-
register_keras_serializable as serializable,
4-
serialize_keras_object as serialize,
5-
)
61
import numpy as np
72

3+
from bayesflow.utils.serialization import serializable, serialize
4+
85
from .elementwise_transform import ElementwiseTransform
96

107

11-
@serializable(package="bayesflow.adapters")
8+
@serializable
129
class ConvertDType(ElementwiseTransform):
1310
"""
1411
Default transform used to convert all floats from float64 to float32 to be in line with keras framework.
@@ -27,21 +24,15 @@ def __init__(self, from_dtype: str, to_dtype: str):
2724
self.from_dtype = from_dtype
2825
self.to_dtype = to_dtype
2926

30-
@classmethod
31-
def from_config(cls, config: dict, custom_objects=None) -> "ConvertDType":
32-
return cls(
33-
from_dtype=deserialize(config["from_dtype"], custom_objects),
34-
to_dtype=deserialize(config["to_dtype"], custom_objects),
35-
)
36-
3727
def get_config(self) -> dict:
38-
return {
39-
"from_dtype": serialize(self.from_dtype),
40-
"to_dtype": serialize(self.to_dtype),
28+
config = {
29+
"from_dtype": self.from_dtype,
30+
"to_dtype": self.to_dtype,
4131
}
32+
return serialize(config)
4233

4334
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
44-
return data.astype(self.to_dtype)
35+
return data.astype(self.to_dtype, copy=False)
4536

4637
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
47-
return data.astype(self.from_dtype)
38+
return data.astype(self.from_dtype, copy=False)

bayesflow/adapters/transforms/drop.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
from collections.abc import Sequence
22

3-
from keras.saving import (
4-
deserialize_keras_object as deserialize,
5-
register_keras_serializable as serializable,
6-
serialize_keras_object as serialize,
7-
)
3+
from bayesflow.utils.serialization import serializable, serialize
84

95
from .transform import Transform
106

117

12-
@serializable(package="bayesflow.adapters")
8+
@serializable
139
class Drop(Transform):
1410
"""
1511
Transform to drop variables from further calculation.
@@ -37,12 +33,8 @@ class Drop(Transform):
3733
def __init__(self, keys: Sequence[str]):
3834
self.keys = keys
3935

40-
@classmethod
41-
def from_config(cls, config: dict, custom_objects=None) -> "Drop":
42-
return cls(keys=deserialize(config["keys"], custom_objects))
43-
4436
def get_config(self) -> dict:
45-
return {"keys": serialize(self.keys)}
37+
return serialize({"keys": self.keys})
4638

4739
def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
4840
# no strict version because there is no requirement for the keys to be present

bayesflow/adapters/transforms/elementwise_transform.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from keras.saving import register_keras_serializable as serializable
21
import numpy as np
32

3+
from bayesflow.utils.serialization import serializable, deserialize
44

5-
@serializable(package="bayesflow.adapters")
5+
6+
@serializable
67
class ElementwiseTransform:
78
"""Base class on which other transforms are based"""
89

@@ -13,8 +14,8 @@ def __call__(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndar
1314
return self.forward(data, **kwargs)
1415

1516
@classmethod
16-
def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
17-
raise NotImplementedError
17+
def from_config(cls, config: dict, custom_objects=None):
18+
return cls(**deserialize(config, custom_objects=custom_objects))
1819

1920
def get_config(self) -> dict:
2021
raise NotImplementedError

bayesflow/adapters/transforms/expand_dims.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
import numpy as np
2-
from keras.saving import (
3-
deserialize_keras_object as deserialize,
4-
register_keras_serializable as serializable,
5-
serialize_keras_object as serialize,
6-
)
2+
3+
from bayesflow.utils.serialization import serializable, serialize
74

85
from .elementwise_transform import ElementwiseTransform
96

107

11-
@serializable(package="bayesflow.adapters")
8+
@serializable
129
class ExpandDims(ElementwiseTransform):
1310
"""
1411
Expand the shape of an array.
@@ -51,16 +48,8 @@ def __init__(self, *, axis: int | tuple):
5148
super().__init__()
5249
self.axis = axis
5350

54-
@classmethod
55-
def from_config(cls, config: dict, custom_objects=None) -> "ExpandDims":
56-
return cls(
57-
axis=deserialize(config["axis"], custom_objects),
58-
)
59-
6051
def get_config(self) -> dict:
61-
return {
62-
"axis": serialize(self.axis),
63-
}
52+
return serialize({"axis": self.axis})
6453

6554
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
6655
return np.expand_dims(data, axis=self.axis)

0 commit comments

Comments
 (0)