Merged
78 commits
cef797e
allow passing networks in flow matching
LarsKue Apr 8, 2025
1be06b0
move and improve find_distribution
LarsKue Apr 8, 2025
17d560d
improve find_network
LarsKue Apr 8, 2025
097e1fa
add improved serialization utils
LarsKue Apr 8, 2025
57b9764
update tests
LarsKue Apr 8, 2025
df21f67
show full keras traceback in tests
LarsKue Apr 9, 2025
253f49f
fix inference network fixture
LarsKue Apr 9, 2025
a523c9e
build manually in serialization test
LarsKue Apr 9, 2025
f70c339
fix serialization and build method in MLP
LarsKue Apr 9, 2025
ee0a327
fix serialization of flow matching w.r.t. building
LarsKue Apr 9, 2025
6883deb
fix old reference to LSTNet
LarsKue Apr 10, 2025
a812430
fix sample weight bug in FFF
LarsKue Apr 10, 2025
393275a
update keras_kwargs -> layer_kwargs, model_kwargs, sequential_kwargs
LarsKue Apr 10, 2025
731cb47
move to using keras.Model instead of keras.Layer
LarsKue Apr 10, 2025
8ba4517
improve stack traces in tests
LarsKue Apr 10, 2025
47354c9
fix generative_inference_network fixture
LarsKue Apr 10, 2025
9485885
add experimental MLP
LarsKue Apr 10, 2025
e00b752
add improved residual implementation
LarsKue Apr 10, 2025
e472835
add experimental resnet implementation
LarsKue Apr 10, 2025
bb6e373
add experimental fully-connected resnet implementation
LarsKue Apr 10, 2025
642a22b
improve default implementation of build method
LarsKue Apr 10, 2025
bebeb8d
improve docs
LarsKue Apr 10, 2025
8a27275
allow networks to be passed directly to FFF
LarsKue Apr 10, 2025
1cf30f8
small improvements
LarsKue Apr 10, 2025
0799408
always pad in double conv
LarsKue Apr 10, 2025
a54ca89
fix build signature in DoubleConv
LarsKue Apr 10, 2025
a12e1c6
fix MLP serialization
LarsKue Apr 10, 2025
fac117b
manually build in test_save_and_load
LarsKue Apr 10, 2025
eb3535c
fix FFF fixture
LarsKue Apr 10, 2025
c16bf80
fix FFF build method
LarsKue Apr 11, 2025
130ba54
move residual to networks
LarsKue Apr 11, 2025
8ecde7f
move new mlp to networks
LarsKue Apr 11, 2025
1e0e09c
allow passing networks directly in PointInferenceNetwork
LarsKue Apr 11, 2025
a3db509
fix test fixture sharing
LarsKue Apr 11, 2025
3347afc
clean up
LarsKue Apr 11, 2025
98a4e51
fix residual and mlp build for sets
LarsKue Apr 11, 2025
b58c832
improve mlp as discussed
LarsKue Apr 11, 2025
e92ac1c
fix build method for MLP for set and non-set compatibility
LarsKue Apr 11, 2025
9686deb
turn InvertibleLayer into Model
LarsKue Apr 11, 2025
e265505
move Residual into its own module
LarsKue Apr 11, 2025
66b30a9
re-add mlp names
LarsKue Apr 11, 2025
6130657
fix FFF build again
LarsKue Apr 11, 2025
48cfeca
improve comments
LarsKue Apr 11, 2025
099b102
filter base config in nets inheriting from Sequential
LarsKue Apr 11, 2025
e3fced4
improve config filtering in other models
LarsKue Apr 11, 2025
8629191
improve serialization of CouplingFlow and its Couplings
LarsKue Apr 11, 2025
ee33154
add MLP tests
LarsKue Apr 11, 2025
3472b9b
turn assert_layers_equal into assert_models_equal for better error me…
LarsKue Apr 11, 2025
a6db7cb
remove superfluous parametrize
LarsKue Apr 11, 2025
42ef8d5
improve assert_layers_equal
LarsKue Apr 11, 2025
03906fb
fix serialization name errors
LarsKue Apr 14, 2025
b236808
improve serialization of continuous time consistency model
LarsKue Apr 14, 2025
3084c14
improve serialization of consistency model
LarsKue Apr 14, 2025
3afdb05
improve serialization of deep set
LarsKue Apr 14, 2025
752dd76
slight improvement to MLP serialization consistency
LarsKue Apr 14, 2025
ee27de9
improve serialization of time series network
LarsKue Apr 14, 2025
c611d86
improve serialization of transformers
LarsKue Apr 14, 2025
4def649
add default from_config to SummaryNetworks
LarsKue Apr 14, 2025
14280de
make serialization recursion collection check more robust
LarsKue Apr 14, 2025
e2199a9
remove assertion that layers must have variables
LarsKue Apr 14, 2025
8d5d81c
temporarily disable name check in assert_models_equal
LarsKue Apr 14, 2025
69f6e3f
implement get_config and from_config in TimeSeriesNetwork
LarsKue Apr 14, 2025
8ec3105
implement set_transformer from_config and get_config
LarsKue Apr 14, 2025
dc56d42
implement deep_set from_config and get_config
LarsKue Apr 14, 2025
a3437fa
fix serialization tests
LarsKue Apr 14, 2025
6e5e191
Merge branch 'dev' into allow-networks
LarsKue Apr 14, 2025
6cce1bc
update serialization for mamba
LarsKue Apr 14, 2025
6cd25c3
remove improved_mlp from experimental
LarsKue Apr 15, 2025
d825315
fix reassignment of activation in loop
LarsKue Apr 15, 2025
2d9be98
Big time revision
stefanradev93 Apr 16, 2025
d8ae2d1
Adapt notebooks and point inference serialization
stefanradev93 Apr 16, 2025
b206e5a
Revert point inference net
stefanradev93 Apr 16, 2025
b49e26c
use new serialization pipeline in adapter
LarsKue Apr 17, 2025
d07687f
add experimental single-thread stack-based monkey patching
LarsKue Apr 17, 2025
5a8d624
use monkey-patching to enable type deserialization
LarsKue Apr 17, 2025
b665830
grammar
LarsKue Apr 17, 2025
755f043
add serialization unit tests
LarsKue Apr 17, 2025
0bf125b
Merge branch 'dev' into allow-networks
LarsKue Apr 17, 2025
2 changes: 2 additions & 0 deletions bayesflow/distributions/__init__.py
@@ -8,6 +8,8 @@
from .diagonal_normal import DiagonalNormal
from .diagonal_student_t import DiagonalStudentT

from .find_distribution import find_distribution

from ..utils._docs import _add_imports_to_all

_add_imports_to_all(include_modules=[])
7 changes: 5 additions & 2 deletions bayesflow/distributions/distribution.py
@@ -1,12 +1,12 @@
import keras

from bayesflow.types import Shape, Tensor
from bayesflow.utils import keras_kwargs
from bayesflow.utils import layer_kwargs


class Distribution(keras.Layer):
def __init__(self, **kwargs):
super().__init__(**keras_kwargs(kwargs))
super().__init__(**layer_kwargs(kwargs))

def call(self, samples: Tensor) -> Tensor:
return keras.ops.exp(self.log_prob(samples))
@@ -16,3 +16,6 @@

def sample(self, batch_shape: Shape) -> Tensor:
raise NotImplementedError

def compute_output_shape(self, input_shape: Shape) -> Shape:
return keras.ops.shape(self.sample(input_shape[0:1]))

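The new `compute_output_shape` default infers the shape by drawing a sample rather than declaring it statically. A minimal sketch of the behavior, assuming `DiagonalNormal` builds from a full `(batch, dims)` input shape:

```python
from bayesflow.distributions import DiagonalNormal

dist = DiagonalNormal()
dist.build((32, 2))  # assumption: build takes the full (batch, dims) shape

# the default samples with batch shape (32,) and returns the sample's shape,
# so a (32, 2) input shape yields a (32, 2) output shape
print(dist.compute_output_shape((32, 2)))
```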
bayesflow/distributions/find_distribution.py
@@ -1,5 +1,7 @@
from functools import singledispatch

from bayesflow.distributions import Distribution


@singledispatch
def find_distribution(arg, **kwargs):
@@ -24,3 +26,8 @@ def _(name: str, *args, **kwargs):
@find_distribution.register
def _(none: None, *args, **kwargs):
return None


@find_distribution.register
def _(distribution: Distribution, *args, **kwargs):
return distribution
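With the new overload, `find_distribution` now resolves strings, `None`, and ready-made `Distribution` instances via `functools.singledispatch`. A minimal sketch of the resulting behavior, assuming `"normal"` is among the names the string overload resolves:

```python
from bayesflow.distributions import DiagonalNormal, find_distribution

dist = DiagonalNormal()

# new overload: instances pass through unchanged, so networks can accept
# either a name or a pre-configured distribution object
assert find_distribution(dist) is dist

# existing overloads are untouched
assert find_distribution(None) is None
normal = find_distribution("normal")
```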
4 changes: 2 additions & 2 deletions bayesflow/experimental/cif/conditional_gaussian.py
@@ -4,7 +4,7 @@
from bayesflow.networks.mlp import MLP

from bayesflow.types import Shape, Tensor
from bayesflow.utils import keras_kwargs
from bayesflow.utils import layer_kwargs


@register_keras_serializable(package="bayesflow.networks.cif")
@@ -32,7 +32,7 @@
The MLP activation function
"""

super().__init__(**keras_kwargs(kwargs))
super().__init__(**layer_kwargs(kwargs))

self.means = MLP([width] * depth, activation=activation)
self.stds = MLP([width] * depth, activation=activation)
self.output_projector = keras.layers.Dense(None)
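The rename comes from commit 393275a, which splits the old `keras_kwargs` helper into `layer_kwargs`, `model_kwargs`, and `sequential_kwargs`. A hedged sketch of the usage pattern the diff follows; the exact filtering semantics (dropping kwargs the base class does not accept) are an assumption:

```python
import keras

from bayesflow.utils import layer_kwargs


class ConditionalBlock(keras.Layer):  # hypothetical layer, for illustration only
    def __init__(self, width: int, **kwargs):
        # forward only the kwargs that keras.Layer.__init__ accepts,
        # presumably filtering out model- or sequential-level arguments
        super().__init__(**layer_kwargs(kwargs))
        self.dense = keras.layers.Dense(width)
```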
72 changes: 44 additions & 28 deletions bayesflow/experimental/continuous_time_consistency_model.py
@@ -1,29 +1,28 @@
import keras
from keras import ops
from keras.saving import (
register_keras_serializable,
)

import numpy as np

import warnings

from bayesflow.networks import MLP
from bayesflow.types import Tensor
from bayesflow.utils import (
jvp,
concatenate_valid,
find_network,
keras_kwargs,
expand_right_as,
expand_right_to,
serialize_value_or_type,
deserialize_value_or_type,
model_kwargs,
)
from bayesflow.utils.serialization import deserialize, serializable, serialize


from bayesflow.networks import InferenceNetwork
from bayesflow.networks.embeddings import FourierEmbedding


@register_keras_serializable(package="bayesflow.networks")
@serializable
class ContinuousTimeConsistencyModel(InferenceNetwork):
"""Implements an sCM (simple, stable, and scalable Consistency Model)
with continuous-time Consistency Training (CT) as described in [1].
@@ -40,8 +39,10 @@

def __init__(
self,
subnet: str | type = "mlp",
subnet: str | keras.Layer = "mlp",
sigma_data: float = 1.0,
subnet_kwargs: dict[str, any] = None,
embedding_kwargs: dict[str, any] = None,
**kwargs,
):
"""Creates an instance of an sCM to be used for consistency training (CT).
@@ -53,39 +54,52 @@
instantiated using subnet_kwargs.
sigma_data : float, optional, default: 1.0
Standard deviation of the target distribution
**kwargs : dict, optional, default: {}
Additional keyword arguments, such as
**kwargs
Additional keyword arguments to the layer.
"""
super().__init__(base_distribution="normal", **keras_kwargs(kwargs))
super().__init__(base_distribution="normal", **kwargs)


if subnet_kwargs:
warnings.warn(

"Using `subnet_kwargs` is deprecated."
"Instead, instantiate the network yourself and pass the arguments directly.",
DeprecationWarning,
)

subnet_kwargs = subnet_kwargs or {}


self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
self.subnet_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
self.subnet = find_network(subnet, **subnet_kwargs)
self.subnet_projector = keras.layers.Dense(

units=None, bias_initializer="zeros", kernel_initializer="zeros", name="subnet_projector"
)

self.weight_fn = find_network("mlp", widths=(256,), dropout=0.0)
self.weight_fn_projector = keras.layers.Dense(units=1, bias_initializer="zeros", kernel_initializer="zeros")
self.weight_fn = MLP([256], dropout=0.0)
self.weight_fn_projector = keras.layers.Dense(

units=1, bias_initializer="zeros", kernel_initializer="zeros", name="weight_fn_projector"
)

self.time_emb = FourierEmbedding(**kwargs.get("embedding_kwargs", {}))
embedding_kwargs = embedding_kwargs or {}
self.time_emb = FourierEmbedding(**embedding_kwargs)

self.time_emb_dim = self.time_emb.embed_dim

self.sigma_data = sigma_data

self.seed_generator = keras.random.SeedGenerator()

# serialization: store all parameters necessary to call __init__
self.config = {
"sigma_data": sigma_data,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**deserialize(config, custom_objects=custom_objects))


def get_config(self):
base_config = super().get_config()
return base_config | self.config
base_config = model_kwargs(base_config)


@classmethod
def from_config(cls, config):
config = deserialize_value_or_type(config, "subnet")
return cls(**config)
config = {

"subnet": self.subnet,
"sigma_data": self.sigma_data,
}

return base_config | serialize(config)


def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs):
t = np.linspace(0.0, np.pi / 2, num_steps)
@@ -206,7 +220,9 @@
out = ops.cos(t) * x - ops.sin(t) * self.sigma_data * f
return out

def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
def compute_metrics(
self, x: Tensor, conditions: Tensor = None, stage: str = "training", **kwargs
) -> dict[str, Tensor]:
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)

# Implements Algorithm 1 from [1]
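Stepping back from the hunks above: because `subnet_kwargs` survives only as a deprecated escape hatch, the intended calling convention is now to build the subnet yourself. A sketch under the assumption that the class is exported from `bayesflow.experimental`:

```python
from bayesflow.experimental import ContinuousTimeConsistencyModel
from bayesflow.networks import MLP

# deprecated: ContinuousTimeConsistencyModel(subnet="mlp", subnet_kwargs={"widths": [256, 256]})
# preferred: instantiate the network and pass it directly
model = ContinuousTimeConsistencyModel(subnet=MLP([256, 256]), sigma_data=1.0)
```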
110 changes: 63 additions & 47 deletions bayesflow/experimental/free_form_flow/free_form_flow.py
@@ -1,24 +1,25 @@
import keras
from keras import ops
from keras.saving import register_keras_serializable as serializable

import warnings

from bayesflow.distributions import Distribution
from bayesflow.types import Tensor
from bayesflow.utils import (
find_network,
keras_kwargs,
concatenate_valid,
find_network,
jacobian,
jvp,
model_kwargs,
vjp,
serialize_value_or_type,
deserialize_value_or_type,
weighted_mean,
)
from bayesflow.utils.serialization import deserialize, serializable, serialize

from bayesflow.networks import InferenceNetwork


@serializable(package="networks.free_form_flow")
@serializable
class FreeFormFlow(InferenceNetwork):
"""Implements a dimensionality-preserving Free-form Flow.
Incorporates ideas from [1-2].
@@ -53,10 +54,12 @@
def __init__(
self,
beta: float = 50.0,
encoder_subnet: str | type = "mlp",
decoder_subnet: str | type = "mlp",
base_distribution: str = "normal",
encoder_subnet: str | keras.Layer = "mlp",
decoder_subnet: str | keras.Layer = "mlp",
base_distribution: str | Distribution = "normal",
hutchinson_sampling: str = "qr",
encoder_subnet_kwargs: dict = None,
decoder_subnet_kwargs: dict = None,
**kwargs,
):
"""Creates an instance of a Free-form Flow.
@@ -80,54 +83,48 @@
**kwargs : dict, optional, default: {}
Additional keyword arguments
"""
super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
super().__init__(base_distribution, **kwargs)

if encoder_subnet == "mlp":
encoder_subnet_kwargs = FreeFormFlow.ENCODER_MLP_DEFAULT_CONFIG.copy()
encoder_subnet_kwargs.update(kwargs.get("encoder_subnet_kwargs", {}))
else:
encoder_subnet_kwargs = kwargs.get("encoder_subnet_kwargs", {})
if encoder_subnet_kwargs or decoder_subnet_kwargs:
warnings.warn(

"Using `subnet_kwargs` is deprecated."
"Instead, instantiate the network yourself and pass the arguments directly.",
DeprecationWarning,
)

self.encoder_subnet = find_network(encoder_subnet, **encoder_subnet_kwargs)
self.encoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
encoder_subnet_kwargs = encoder_subnet_kwargs or {}
decoder_subnet_kwargs = decoder_subnet_kwargs or {}

if encoder_subnet == "mlp":
encoder_subnet_kwargs = FreeFormFlow.ENCODER_MLP_DEFAULT_CONFIG.copy() | encoder_subnet_kwargs


if decoder_subnet == "mlp":
decoder_subnet_kwargs = FreeFormFlow.DECODER_MLP_DEFAULT_CONFIG.copy()
decoder_subnet_kwargs.update(kwargs.get("decoder_subnet_kwargs", {}))
else:
decoder_subnet_kwargs = kwargs.get("decoder_subnet_kwargs", {})
decoder_subnet_kwargs = FreeFormFlow.DECODER_MLP_DEFAULT_CONFIG.copy() | decoder_subnet_kwargs


self.encoder_subnet = find_network(encoder_subnet, **encoder_subnet_kwargs)
self.encoder_projector = keras.layers.Dense(
units=None, bias_initializer="zeros", kernel_initializer="zeros", name="encoder_projector"
)

self.decoder_subnet = find_network(decoder_subnet, **decoder_subnet_kwargs)
self.decoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
self.decoder_projector = keras.layers.Dense(
units=None, bias_initializer="zeros", kernel_initializer="zeros", name="decoder_projector"
)

self.hutchinson_sampling = hutchinson_sampling
self.beta = beta

self.seed_generator = keras.random.SeedGenerator()

# serialization: store all parameters necessary to call __init__
self.config = {
"beta": beta,
"base_distribution": base_distribution,
"hutchinson_sampling": hutchinson_sampling,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "encoder_subnet", encoder_subnet)
self.config = serialize_value_or_type(self.config, "decoder_subnet", decoder_subnet)

def get_config(self):
base_config = super().get_config()
return base_config | self.config

@classmethod
def from_config(cls, config):
config = deserialize_value_or_type(config, "encoder_subnet")
config = deserialize_value_or_type(config, "decoder_subnet")
return cls(**config)

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
super().build(xz_shape)
if self.built:
# building when the network is already built can cause issues with serialization
# see https://github.com/keras-team/keras/issues/21147
return


self.base_distribution.build(xz_shape)

self.encoder_projector.units = xz_shape[-1]
self.decoder_projector.units = xz_shape[-1]

@@ -142,11 +139,11 @@
self.encoder_subnet.build(input_shape)
self.decoder_subnet.build(input_shape)

input_shape = self.encoder_subnet.compute_output_shape(input_shape)
self.encoder_projector.build(input_shape)
output_shape = self.encoder_subnet.compute_output_shape(input_shape)
self.encoder_projector.build(output_shape)

input_shape = self.decoder_subnet.compute_output_shape(input_shape)
self.decoder_projector.build(input_shape)
output_shape = self.decoder_subnet.compute_output_shape(input_shape)
self.decoder_projector.build(output_shape)

def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
@@ -244,3 +241,22 @@
loss = weighted_mean(losses, sample_weight)

return base_metrics | {"loss": loss}

@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**deserialize(config, custom_objects=custom_objects))

def get_config(self):
base_config = super().get_config()
base_config = model_kwargs(base_config)

config = {
"beta": self.beta,
"encoder_subnet": self.encoder_subnet,
"decoder_subnet": self.decoder_subnet,
"base_distribution": self.base_distribution,
"hutchinson_sampling": self.hutchinson_sampling,
# we do not need to store subnet_kwargs
}

return base_config | serialize(config)
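With `get_config` and `from_config` now delegating to the shared `serialize`/`deserialize` helpers, a save-and-load round trip should work even when subnets are passed as instances. A hedged sketch; the export path and build signature are taken from the diff, not verified:

```python
import keras

from bayesflow.experimental import FreeFormFlow
from bayesflow.networks import MLP

# networks are passed directly instead of via encoder_subnet_kwargs/decoder_subnet_kwargs
flow = FreeFormFlow(encoder_subnet=MLP([128, 128]), decoder_subnet=MLP([128, 128]))
flow.build(xz_shape=(None, 4))

flow.save("flow.keras")
restored = keras.saving.load_model("flow.keras")
```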
1 change: 1 addition & 0 deletions bayesflow/experimental/improved_mlp/__init__.py
@@ -0,0 +1 @@
from .mlp import MLP
