Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
78 commits
Select commit Hold shift + click to select a range
cef797e
allow passing networks in flow matching
LarsKue Apr 8, 2025
1be06b0
move and improve find_distribution
LarsKue Apr 8, 2025
17d560d
improve find_network
LarsKue Apr 8, 2025
097e1fa
add improved serialization utils
LarsKue Apr 8, 2025
57b9764
update tests
LarsKue Apr 8, 2025
df21f67
show full keras traceback in tests
LarsKue Apr 9, 2025
253f49f
fix inference network fixture
LarsKue Apr 9, 2025
a523c9e
build manually in serialization test
LarsKue Apr 9, 2025
f70c339
fix serialization and build method in MLP
LarsKue Apr 9, 2025
ee0a327
fix serialization of flow matching w.r.t. building
LarsKue Apr 9, 2025
6883deb
fix old reference to LSTNet
LarsKue Apr 10, 2025
a812430
fix sample weight bug in FFF
LarsKue Apr 10, 2025
393275a
update keras_kwargs -> layer_kwargs, model_kwargs, sequential_kwargs
LarsKue Apr 10, 2025
731cb47
move to using keras.Model instead of keras.Layer
LarsKue Apr 10, 2025
8ba4517
improve stack traces in tests
LarsKue Apr 10, 2025
47354c9
fix generative_inference_network fixture
LarsKue Apr 10, 2025
9485885
add experimental MLP
LarsKue Apr 10, 2025
e00b752
add improved residual implementation
LarsKue Apr 10, 2025
e472835
add experimental resnet implementation
LarsKue Apr 10, 2025
bb6e373
add experimental fully-connected resnet implementation
LarsKue Apr 10, 2025
642a22b
improve default implementation of build method
LarsKue Apr 10, 2025
bebeb8d
improve docs
LarsKue Apr 10, 2025
8a27275
allow networks to be passed directly to FFF
LarsKue Apr 10, 2025
1cf30f8
small improvements
LarsKue Apr 10, 2025
0799408
always pad in double conv
LarsKue Apr 10, 2025
a54ca89
fix build signature in DoubleConv
LarsKue Apr 10, 2025
a12e1c6
fix MLP serialization
LarsKue Apr 10, 2025
fac117b
manually build in test_save_and_load
LarsKue Apr 10, 2025
eb3535c
fix FFF fixture
LarsKue Apr 10, 2025
c16bf80
fix FFF build method
LarsKue Apr 11, 2025
130ba54
move residual to networks
LarsKue Apr 11, 2025
8ecde7f
move new mlp to networks
LarsKue Apr 11, 2025
1e0e09c
allow passing networks directly in PointInferenceNetwork
LarsKue Apr 11, 2025
a3db509
fix test fixture sharing
LarsKue Apr 11, 2025
3347afc
clean up
LarsKue Apr 11, 2025
98a4e51
fix residual and mlp build for sets
LarsKue Apr 11, 2025
b58c832
improve mlp as discussed
LarsKue Apr 11, 2025
e92ac1c
fix build method for MLP for set and non-set compatibility
LarsKue Apr 11, 2025
9686deb
turn InvertibleLayer into Model
LarsKue Apr 11, 2025
e265505
move Residual into its own module
LarsKue Apr 11, 2025
66b30a9
re-add mlp names
LarsKue Apr 11, 2025
6130657
fix FFF build again
LarsKue Apr 11, 2025
48cfeca
improve comments
LarsKue Apr 11, 2025
099b102
filter base config in nets inheriting from Sequential
LarsKue Apr 11, 2025
e3fced4
improve config filtering in other models
LarsKue Apr 11, 2025
8629191
improve serialization of CouplingFlow and its Couplings
LarsKue Apr 11, 2025
ee33154
add MLP tests
LarsKue Apr 11, 2025
3472b9b
turn assert_layers_equal into assert_models_equal for better error me…
LarsKue Apr 11, 2025
a6db7cb
remove superfluous parametrize
LarsKue Apr 11, 2025
42ef8d5
improve assert_layers_equal
LarsKue Apr 11, 2025
03906fb
fix serialization name errors
LarsKue Apr 14, 2025
b236808
improve serialization of contiuous time consistency model
LarsKue Apr 14, 2025
3084c14
improve serialization of consistency model
LarsKue Apr 14, 2025
3afdb05
improve serialization of deep set
LarsKue Apr 14, 2025
752dd76
slight improvement to MLP serialization consistency
LarsKue Apr 14, 2025
ee27de9
improve serialization of time series network
LarsKue Apr 14, 2025
c611d86
improve serialization of transformers
LarsKue Apr 14, 2025
4def649
add default from_config to SummaryNetworks
LarsKue Apr 14, 2025
14280de
make serialization recursion collection check more robust
LarsKue Apr 14, 2025
e2199a9
remove assertion that layers must have variables
LarsKue Apr 14, 2025
8d5d81c
temporarily disable name check in assert_models_equal
LarsKue Apr 14, 2025
69f6e3f
implement get_config and from_config in TimeSeriesNetwork
LarsKue Apr 14, 2025
8ec3105
implement set_transformer from_config and get_config
LarsKue Apr 14, 2025
dc56d42
implement deep_set from_config and get_config
LarsKue Apr 14, 2025
a3437fa
fix serialization tests
LarsKue Apr 14, 2025
6e5e191
Merge branch 'dev' into allow-networks
LarsKue Apr 14, 2025
6cce1bc
update serialization for mamba
LarsKue Apr 14, 2025
6cd25c3
remove improved_mlp from experimental
LarsKue Apr 15, 2025
d825315
fix reassignment of activation in loop
LarsKue Apr 15, 2025
2d9be98
Big time revision
stefanradev93 Apr 16, 2025
d8ae2d1
Adapt notebooks and point inference serialization
stefanradev93 Apr 16, 2025
b206e5a
Revert point inference net
stefanradev93 Apr 16, 2025
b49e26c
use new serialization pipeline in adapter
LarsKue Apr 17, 2025
d07687f
add experimental single-thread stack-based monkey patching
LarsKue Apr 17, 2025
5a8d624
use monkey-patching to enable type deserialization
LarsKue Apr 17, 2025
b665830
grammar
LarsKue Apr 17, 2025
755f043
add serialization unit tests
LarsKue Apr 17, 2025
0bf125b
Merge branch 'dev' into allow-networks
LarsKue Apr 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@

import numpy as np

from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)
from bayesflow.utils.serialization import deserialize, serialize, serializable

from .transforms import (
AsSet,
Expand All @@ -33,7 +29,7 @@
from .transforms.filter_transform import Predicate


@serializable(package="bayesflow.adapters")
@serializable
class Adapter(MutableSequence[Transform]):
"""
Defines an adapter to apply various transforms to data.
Expand Down Expand Up @@ -74,10 +70,14 @@ def create_default(inference_variables: Sequence[str]) -> "Adapter":

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Adapter":
return cls(transforms=deserialize(config["transforms"], custom_objects))
return cls(**deserialize(config, custom_objects=custom_objects))

def get_config(self) -> dict:
return {"transforms": serialize(self.transforms)}
config = {
"transforms": self.transforms,
}

return serialize(config)

def forward(self, data: dict[str, any], **kwargs) -> dict[str, np.ndarray]:
"""Apply the transforms in the forward direction.
Expand Down
9 changes: 3 additions & 6 deletions bayesflow/adapters/transforms/as_set.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from keras.saving import register_keras_serializable as serializable
import numpy as np

from bayesflow.utils.serialization import serializable

from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
@serializable
class AsSet(ElementwiseTransform):
"""The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets.

Expand Down Expand Up @@ -33,9 +34,5 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:

return data

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "AsSet":
return cls()

def get_config(self) -> dict:
return {}
9 changes: 3 additions & 6 deletions bayesflow/adapters/transforms/as_time_series.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np
from keras.saving import register_keras_serializable as serializable

from bayesflow.utils.serialization import serializable

from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
@serializable
class AsTimeSeries(ElementwiseTransform):
"""The `.as_time_series` transform can be used to indicate that variables shall be treated as time series.

Expand All @@ -29,9 +30,5 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:

return data

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "AsTimeSeries":
return cls()

def get_config(self) -> dict:
return {}
38 changes: 9 additions & 29 deletions bayesflow/adapters/transforms/broadcast.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from collections.abc import Sequence
import numpy as np

from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)
from bayesflow.utils.serialization import serialize, serializable

from .transform import Transform


@serializable(package="bayesflow.adapters")
@serializable
class Broadcast(Transform):
"""
Broadcasts arrays or scalars to the shape of a given other array.
Expand Down Expand Up @@ -96,31 +92,15 @@ def __init__(
self.exclude = exclude
self.squeeze = squeeze

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Broadcast":
# Deserialize turns tuples to lists, undo it if necessary
exclude = deserialize(config["exclude"], custom_objects)
exclude = tuple(exclude) if isinstance(exclude, list) else exclude
expand = deserialize(config["expand"], custom_objects)
expand = tuple(expand) if isinstance(expand, list) else expand
squeeze = deserialize(config["squeeze"], custom_objects)
squeeze = tuple(squeeze) if isinstance(squeeze, list) else squeeze
return cls(
keys=deserialize(config["keys"], custom_objects),
to=deserialize(config["to"], custom_objects),
expand=expand,
exclude=exclude,
squeeze=squeeze,
)

def get_config(self) -> dict:
return {
"keys": serialize(self.keys),
"to": serialize(self.to),
"expand": serialize(self.expand),
"exclude": serialize(self.exclude),
"squeeze": serialize(self.squeeze),
config = {
"keys": self.keys,
"to": self.to,
"expand": self.expand,
"exclude": self.exclude,
"squeeze": self.squeeze,
}
return serialize(config)

# noinspection PyMethodOverriding
def forward(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
Expand Down
33 changes: 11 additions & 22 deletions bayesflow/adapters/transforms/concatenate.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from collections.abc import Sequence

import numpy as np
from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)

from bayesflow.utils.serialization import serialize, serializable

from .transform import Transform


@serializable(package="bayesflow.adapters")
@serializable
class Concatenate(Transform):
"""Concatenate multiple arrays into a new key. Used to specify how data variables should be treated by the network.

Expand All @@ -35,29 +32,21 @@ class Concatenate(Transform):
)
"""

def __init__(self, keys: Sequence[str], *, into: str, axis: int = -1, _indices: list | None = None):
def __init__(self, keys: Sequence[str], *, into: str, axis: int = -1, indices: list | None = None):
self.keys = keys
self.into = into
self.axis = axis

self.indices = _indices

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Concatenate":
return cls(
keys=deserialize(config["keys"], custom_objects),
into=deserialize(config["into"], custom_objects),
axis=deserialize(config["axis"], custom_objects),
_indices=deserialize(config["indices"], custom_objects),
)
self.indices = indices

def get_config(self) -> dict:
return {
"keys": serialize(self.keys),
"into": serialize(self.into),
"axis": serialize(self.axis),
"indices": serialize(self.indices),
config = {
"keys": self.keys,
"into": self.into,
"axis": self.axis,
"indices": self.indices,
}
return serialize(config)

def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]:
if not strict and self.indices is None:
Expand Down
13 changes: 4 additions & 9 deletions bayesflow/adapters/transforms/constrain.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from keras.saving import (
register_keras_serializable as serializable,
)
import numpy as np

from bayesflow.utils.serialization import serializable, serialize
from bayesflow.utils.numpy_utils import (
inverse_sigmoid,
inverse_softplus,
Expand All @@ -13,7 +11,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
@serializable
class Constrain(ElementwiseTransform):
"""
Constrains neural network predictions of a data variable to specified bounds.
Expand Down Expand Up @@ -163,18 +161,15 @@ def unconstrain(x):
case other:
raise ValueError(f"Unsupported value for 'inclusive': {other!r}.")

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Constrain":
return cls(**config)

def get_config(self) -> dict:
return {
config = {
"lower": self.lower,
"upper": self.upper,
"method": self.method,
"inclusive": self.inclusive,
"epsilon": self.epsilon,
}
return serialize(config)

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
# forward means data space -> network space, so unconstrain the data
Expand Down
27 changes: 9 additions & 18 deletions bayesflow/adapters/transforms/convert_dtype.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)
import numpy as np

from bayesflow.utils.serialization import serializable, serialize

from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
@serializable
class ConvertDType(ElementwiseTransform):
"""
Default transform used to convert all floats from float64 to float32 to be in line with keras framework.
Expand All @@ -27,21 +24,15 @@ def __init__(self, from_dtype: str, to_dtype: str):
self.from_dtype = from_dtype
self.to_dtype = to_dtype

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "ConvertDType":
return cls(
from_dtype=deserialize(config["from_dtype"], custom_objects),
to_dtype=deserialize(config["to_dtype"], custom_objects),
)

def get_config(self) -> dict:
return {
"from_dtype": serialize(self.from_dtype),
"to_dtype": serialize(self.to_dtype),
config = {
"from_dtype": self.from_dtype,
"to_dtype": self.to_dtype,
}
return serialize(config)

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return data.astype(self.to_dtype)
return data.astype(self.to_dtype, copy=False)

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return data.astype(self.from_dtype)
return data.astype(self.from_dtype, copy=False)
14 changes: 3 additions & 11 deletions bayesflow/adapters/transforms/drop.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from collections.abc import Sequence

from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)
from bayesflow.utils.serialization import serializable, serialize

from .transform import Transform


@serializable(package="bayesflow.adapters")
@serializable
class Drop(Transform):
"""
Transform to drop variables from further calculation.
Expand Down Expand Up @@ -37,12 +33,8 @@ class Drop(Transform):
def __init__(self, keys: Sequence[str]):
self.keys = keys

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Drop":
return cls(keys=deserialize(config["keys"], custom_objects))

def get_config(self) -> dict:
return {"keys": serialize(self.keys)}
return serialize({"keys": self.keys})

def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
# no strict version because there is no requirement for the keys to be present
Expand Down
9 changes: 5 additions & 4 deletions bayesflow/adapters/transforms/elementwise_transform.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from keras.saving import register_keras_serializable as serializable
import numpy as np

from bayesflow.utils.serialization import serializable, deserialize

@serializable(package="bayesflow.adapters")

@serializable
class ElementwiseTransform:
"""Base class on which other transforms are based"""

Expand All @@ -13,8 +14,8 @@ def __call__(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndar
return self.forward(data, **kwargs)

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
raise NotImplementedError
def from_config(cls, config: dict, custom_objects=None):
return cls(**deserialize(config, custom_objects=custom_objects))

def get_config(self) -> dict:
raise NotImplementedError
Expand Down
19 changes: 4 additions & 15 deletions bayesflow/adapters/transforms/expand_dims.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import numpy as np
from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)

from bayesflow.utils.serialization import serializable, serialize

from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
@serializable
class ExpandDims(ElementwiseTransform):
"""
Expand the shape of an array.
Expand Down Expand Up @@ -51,16 +48,8 @@ def __init__(self, *, axis: int | tuple):
super().__init__()
self.axis = axis

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "ExpandDims":
return cls(
axis=deserialize(config["axis"], custom_objects),
)

def get_config(self) -> dict:
return {
"axis": serialize(self.axis),
}
return serialize({"axis": self.axis})

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.expand_dims(data, axis=self.axis)
Expand Down
Loading