Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ It provides users and researchers with:
BayesFlow (version 2+) is designed to be a flexible and efficient tool that enables rapid statistical inference
fueled by continuous progress in generative AI and Bayesian inference.

> [!IMPORTANT]
> As the 2.0 version introduced many new features, we still have to make breaking changes from time to time.
> This especially concerns **saving and loading** of models. We aim to stabilize this from the 2.1 release onwards.
> Until then, consider pinning your BayesFlow 2.0 installation to an exact version, or re-training after an update
> for less costly models.

## Important Note for Existing Users

You are currently looking at BayesFlow 2.0+, which is a complete rewrite of the library.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .transforms.filter_transform import Predicate


@serializable
@serializable("bayesflow.adapters")
class Adapter(MutableSequence[Transform]):
"""
Defines an adapter to apply various transforms to data.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/as_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class AsSet(ElementwiseTransform):
"""The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets.

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/as_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class AsTimeSeries(ElementwiseTransform):
"""The `.as_time_series` transform can be used to indicate that variables shall be treated as time series.

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .transform import Transform


@serializable
@serializable("bayesflow.adapters")
class Broadcast(Transform):
"""
Broadcasts arrays or scalars to the shape of a given other array.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .transform import Transform


@serializable
@serializable("bayesflow.adapters")
class Concatenate(Transform):
"""Concatenate multiple arrays into a new key. Used to specify how data variables should be treated by the network.

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/constrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class Constrain(ElementwiseTransform):
"""
Constrains neural network predictions of a data variable to specified bounds.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/convert_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class ConvertDType(ElementwiseTransform):
"""
Default transform used to convert all floats from float64 to float32 to be in line with keras framework.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .transform import Transform


@serializable
@serializable("bayesflow.adapters")
class Drop(Transform):
"""
Transform to drop variables from further calculation.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/elementwise_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from bayesflow.utils.serialization import serializable, deserialize


@serializable
@serializable("bayesflow.adapters")
class ElementwiseTransform:
"""Base class on which other transforms are based"""

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/expand_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class ExpandDims(ElementwiseTransform):
"""
Expand the shape of an array.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/filter_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __call__(self, key: str, value: np.ndarray, inverse: bool) -> bool:
raise NotImplementedError


@serializable
@serializable("bayesflow.adapters")
class FilterTransform(Transform):
"""
Implements a transform that applies a different transform on a subset of the data.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .transform import Transform


@serializable
@serializable("bayesflow.adapters")
class Keep(Transform):
"""
Name the data parameters that should be kept for futher calculation.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class Log(ElementwiseTransform):
"""Log transforms a variable.

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/map_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .transform import Transform


@serializable
@serializable("bayesflow.adapters")
class MapTransform(Transform):
"""
Implements a transform that applies a set of elementwise transforms
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/numpy_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class NumpyTransform(ElementwiseTransform):
"""
A class to apply element-wise transformations using plain NumPy functions.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/one_hot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class OneHot(ElementwiseTransform):
"""
Changes data to be one-hot encoded.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .transform import Transform


@serializable
@serializable("bayesflow.adapters")
class Rename(Transform):
"""
Transform to rename keys in data dictionary. Useful to rename variables to match those required by
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class Scale(ElementwiseTransform):
def __init__(self, scale: np.typing.ArrayLike):
self.scale = np.array(scale)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
import numpy as np
from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
get_registered_name,
get_registered_object,
)

from bayesflow.utils.serialization import serializable
from .elementwise_transform import ElementwiseTransform
from ...utils import filter_kwargs
import inspect


@serializable(package="bayesflow.adapters")
@serializable("bayesflow.adapters")
class SerializableCustomTransform(ElementwiseTransform):
"""
Transforms a parameter using a pair of registered serializable forward and inverse functions.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class Shift(ElementwiseTransform):
def __init__(self, shift: np.typing.ArrayLike):
self.shift = np.array(shift)
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .transform import Transform


@serializable
@serializable("bayesflow.adapters")
class Split(Transform):
"""This is the effective inverse of the :py:class:`~Concatenate` Transform.

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/sqrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class Sqrt(ElementwiseTransform):
"""Square-root transform a variable.

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/standardize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class Standardize(ElementwiseTransform):
"""
Transform that when applied standardizes data using typical z-score standardization
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/to_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable
@serializable("bayesflow.adapters")
class ToArray(ElementwiseTransform):
"""
Checks provided data for any non-arrays and converts them to numpy arrays.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/to_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .transform import Transform


@serializable
@serializable("bayesflow.adapters")
class ToDict(Transform):
"""Convert non-dict batches (e.g., pandas.DataFrame) to dict batches"""

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from bayesflow.utils.serialization import serializable, deserialize


@serializable
@serializable("bayesflow.adapters")
class Transform:
"""
Base class on which other transforms are based
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/approximators/continuous_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .approximator import Approximator


@serializable
@serializable("bayesflow.approximators")
class ContinuousApproximator(Approximator):
"""
Defines a workflow for performing fast posterior or likelihood inference.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/approximators/model_comparison_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .approximator import Approximator


@serializable
@serializable("bayesflow.approximators")
class ModelComparisonApproximator(Approximator):
"""
Defines an approximator for model (simulator) comparison, where the (discrete) posterior model probabilities are
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/approximators/point_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .continuous_approximator import ContinuousApproximator


@serializable
@serializable("bayesflow.approximators")
class PointApproximator(ContinuousApproximator):
"""
A workflow for fast amortized point estimation of a conditional distribution.
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/distributions/diagonal_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .distribution import Distribution


@serializable
@serializable("bayesflow.distributions")
class DiagonalNormal(Distribution):
"""Implements a backend-agnostic diagonal Gaussian distribution."""

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/distributions/diagonal_student_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .distribution import Distribution


@serializable
@serializable("bayesflow.distributions")
class DiagonalStudentT(Distribution):
"""Implements a backend-agnostic diagonal Student-t distribution."""

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from bayesflow.utils.serialization import serializable, deserialize


@serializable
@serializable("bayesflow.distributions")
class Distribution(keras.Layer):
def __init__(self, **kwargs):
super().__init__(**layer_kwargs(kwargs))
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from bayesflow.distributions import Distribution


@serializable
@serializable("bayesflow.distributions")
class Mixture(Distribution):
"""Utility class for a backend-agnostic mixture distributions."""

Expand Down
5 changes: 3 additions & 2 deletions bayesflow/experimental/cif/cif.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor
from bayesflow.utils.serialization import serializable

from bayesflow.networks.inference_network import InferenceNetwork
from bayesflow.networks.coupling_flow import CouplingFlow

from .conditional_gaussian import ConditionalGaussian


@serializable(package="bayesflow.networks")
# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)
class CIF(InferenceNetwork):
"""Implements a continuously indexed flow (CIF) with a `CouplingFlow`
bijection and `ConditionalGaussian` distributions p and q. Improves on
Expand Down
5 changes: 3 additions & 2 deletions bayesflow/experimental/cif/conditional_gaussian.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import keras
from keras.saving import register_keras_serializable
import numpy as np
from bayesflow.networks.mlp import MLP

from bayesflow.types import Shape, Tensor
from bayesflow.utils import layer_kwargs
from bayesflow.utils.serialization import serializable


@register_keras_serializable(package="bayesflow.networks.cif")
# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)
class ConditionalGaussian(keras.Layer):
"""Implements a conditional gaussian distribution with neural networks for
the means and standard deviations respectively. Bulit in reference to [1].
Expand Down
3 changes: 2 additions & 1 deletion bayesflow/experimental/continuous_time_consistency_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from bayesflow.networks.embeddings import FourierEmbedding


@serializable
# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)
class ContinuousTimeConsistencyModel(InferenceNetwork):
"""Implements an sCM (simple, stable, and scalable Consistency Model)
with continous-time Consistency Training (CT) as described in [1].
Expand Down
3 changes: 2 additions & 1 deletion bayesflow/experimental/free_form_flow/free_form_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from bayesflow.networks import InferenceNetwork


@serializable
# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)
class FreeFormFlow(InferenceNetwork):
"""Implements a dimensionality-preserving Free-form Flow.
Incorporates ideas from [1-2].
Expand Down
3 changes: 2 additions & 1 deletion bayesflow/experimental/resnet/dense_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from .double_linear import DoubleLinear


@serializable
# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)

Check warning on line 12 in bayesflow/experimental/resnet/dense_resnet.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/experimental/resnet/dense_resnet.py#L12

Added line #L12 was not covered by tests
class DenseResNet(keras.Sequential):
"""
Implements the fully-connected analogue of the ResNet architecture.
Expand Down
3 changes: 2 additions & 1 deletion bayesflow/experimental/resnet/double_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from bayesflow.utils.serialization import deserialize, serializable, serialize


@serializable
# disable module check, use potential module after moving from experimental
@serializable("bayesflow.networks", disable_module_check=True)

Check warning on line 9 in bayesflow/experimental/resnet/double_conv.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/experimental/resnet/double_conv.py#L9

Added line #L9 was not covered by tests
class DoubleConv(keras.Sequential):
def __init__(
self,
Expand Down
Loading