diff --git a/README.md b/README.md
index c4049c470..9b5ef3d47 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,25 @@ It provides users and researchers with:
BayesFlow (version 2+) is designed to be a flexible and efficient tool that enables rapid statistical inference
fueled by continuous progress in generative AI and Bayesian inference.
+> [!IMPORTANT]
+> As the 2.0 version introduced many new features, we still have to make breaking changes from time to time.
+> This especially concerns **saving and loading** of models. We aim to stabilize this from the 2.1 release onwards.
+> Until then, consider pinning your BayesFlow 2.0 installation to an exact version, or re-training after an update
+> for less costly models.
+
+## Important Note for Existing Users
+
+You are currently looking at BayesFlow 2.0+, which is a complete rewrite of the library.
+While it shares the same overall goals with the 1.x versions, the API is not compatible.
+
+> [!CAUTION]
+> A few features, most notably hierarchical models, have not been ported to BayesFlow 2.0+
+> yet. We are working on those features and plan to add them soon. You can find the complete
+> list in the [FAQ](#faq) below.
+
+The [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb) guide
+highlights how concepts and classes relate between the two versions.
+
## Conceptual Overview
@@ -216,11 +235,48 @@ while the old version was based on TensorFlow.
-------------
+**Question:**
+Should I switch to BayesFlow 2.0+ now? Are there features that are still missing?
+
+**Answer:**
+In general, we recommend to switch, as the new version is easier to use and will continue
+to receive improvements and new features. However, a few features are still missing, so you
+might want to wait until everything you need has been ported to BayesFlow 2.0+.
+
+Depending on your needs, you might not want to upgrade yet if one of the following applies:
+
+- You have an ongoing project that uses BayesFlow 1.x, and you do not want to allocate
+ time for migrating it to the new API.
+- You have already trained models in BayesFlow 1.x, that you do not want to re-train
+ with the new version. Loading models from version 1.x in version 2.0+ is not supported.
+- You require a feature that was not ported to BayesFlow 2.0+ yet. To our knowledge,
+ this applies to:
+ * Two-level/Hierarchical models (planned for version 2.1): `TwoLevelGenerativeModel`, `TwoLevelPrior`.
+ * Sensitivity analysis (partially discontinued): functionality from the `bayesflow.sensitivity` module. This is still
+ possible, but we do no longer offer a special module for it. We plan to add a tutorial on this, see [#455](https://github.com/bayesflow-org/bayesflow/issues/455).
+ * MCMC (discontinued): The `bayesflow.mcmc` module. We are considering other options
+ to enable the use of BayesFlow in an MCMC setting.
+ * Networks: `EvidentialNetwork`.
+ * Model misspecification detection: MMD test in the summary space (see #384).
+
+If you encounter any functionality that is missing and not listed here, please let us
+know by opening an issue.
+
+-------------
+
**Question:**
I still need the old BayesFlow for some of my projects. How can I install it?
**Answer:**
You can find and install the old Bayesflow version via the `stable-legacy` branch on GitHub.
+The corresponding [documentation](https://bayesflow.org/stable-legacy/index.html) can be
+accessed by selecting the "stable-legacy" entry in the version picker of the documentation.
+
+You can also install the latest version of BayesFlow v1.x from PyPI using
+
+```
+pip install "bayesflow<2.0"
+```
-------------
diff --git a/bayesflow/__init__.py b/bayesflow/__init__.py
index 5a28ffe2e..008afc89b 100644
--- a/bayesflow/__init__.py
+++ b/bayesflow/__init__.py
@@ -50,17 +50,11 @@ def setup():
"in contexts where you need gradients (e.g. custom training loops)."
)
+ # dynamically add __version__ attribute
+ from importlib.metadata import version
-# dynamically add version dunder variable
-try:
- from importlib.metadata import version, PackageNotFoundError
+ globals()["__version__"] = version("bayesflow")
- __version__ = version(__name__)
-except PackageNotFoundError:
- __version__ = "2.0.0"
-finally:
- del version
- del PackageNotFoundError
# call and clean up namespace
setup()
diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py
index 4db738eef..a17a59d81 100644
--- a/bayesflow/adapters/adapter.py
+++ b/bayesflow/adapters/adapter.py
@@ -29,7 +29,7 @@
from .transforms.filter_transform import Predicate
-@serializable
+@serializable("bayesflow.adapters")
class Adapter(MutableSequence[Transform]):
"""
Defines an adapter to apply various transforms to data.
@@ -79,7 +79,9 @@ def get_config(self) -> dict:
return serialize(config)
- def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]:
+ def forward(
+ self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
+ ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""Apply the transforms in the forward direction.
Parameters
@@ -88,22 +90,33 @@ def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -
The data to be transformed.
stage : str, one of ["training", "validation", "inference"]
The stage the function is called in.
+ log_det_jac: bool, optional
+ Whether to return the log determinant of the Jacobian of the transforms.
**kwargs : dict
Additional keyword arguments passed to each transform.
Returns
-------
- dict
- The transformed data.
+ dict | tuple[dict, dict]
+ The transformed data or tuple of transformed data and log determinant of the Jacobian.
"""
data = data.copy()
+ if not log_det_jac:
+ for transform in self.transforms:
+ data = transform(data, stage=stage, **kwargs)
+ return data
+ log_det_jac = {}
for transform in self.transforms:
- data = transform(data, stage=stage, **kwargs)
+ transformed_data = transform(data, stage=stage, **kwargs)
+ log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
+ data = transformed_data
- return data
+ return data, log_det_jac
- def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]:
+ def inverse(
+ self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
+ ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""Apply the transforms in the inverse direction.
Parameters
@@ -112,24 +125,32 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kw
The data to be transformed.
stage : str, one of ["training", "validation", "inference"]
The stage the function is called in.
+ log_det_jac: bool, optional
+ Whether to return the log determinant of the Jacobian of the transforms.
**kwargs : dict
Additional keyword arguments passed to each transform.
Returns
-------
- dict
- The transformed data.
+ dict | tuple[dict, dict]
+ The transformed data or tuple of transformed data and log determinant of the Jacobian.
"""
data = data.copy()
+ if not log_det_jac:
+ for transform in reversed(self.transforms):
+ data = transform(data, stage=stage, inverse=True, **kwargs)
+ return data
+ log_det_jac = {}
for transform in reversed(self.transforms):
data = transform(data, stage=stage, inverse=True, **kwargs)
+ log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs)
- return data
+ return data, log_det_jac
def __call__(
self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
- ) -> dict[str, np.ndarray]:
+ ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""Apply the transforms in the given direction.
Parameters
@@ -145,8 +166,8 @@ def __call__(
Returns
-------
- dict
- The transformed data.
+ dict | tuple[dict, dict]
+ The transformed data or tuple of transformed data and log determinant of the Jacobian.
"""
if inverse:
return self.inverse(data, stage=stage, **kwargs)
diff --git a/bayesflow/adapters/transforms/as_set.py b/bayesflow/adapters/transforms/as_set.py
index f4d5bdfc5..903536bc4 100644
--- a/bayesflow/adapters/transforms/as_set.py
+++ b/bayesflow/adapters/transforms/as_set.py
@@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class AsSet(ElementwiseTransform):
"""The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets.
diff --git a/bayesflow/adapters/transforms/as_time_series.py b/bayesflow/adapters/transforms/as_time_series.py
index 3f4d2a2c5..d7791352c 100644
--- a/bayesflow/adapters/transforms/as_time_series.py
+++ b/bayesflow/adapters/transforms/as_time_series.py
@@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class AsTimeSeries(ElementwiseTransform):
"""The `.as_time_series` transform can be used to indicate that variables shall be treated as time series.
diff --git a/bayesflow/adapters/transforms/broadcast.py b/bayesflow/adapters/transforms/broadcast.py
index 8667ec0c7..646e1f72e 100644
--- a/bayesflow/adapters/transforms/broadcast.py
+++ b/bayesflow/adapters/transforms/broadcast.py
@@ -6,7 +6,7 @@
from .transform import Transform
-@serializable
+@serializable("bayesflow.adapters")
class Broadcast(Transform):
"""
Broadcasts arrays or scalars to the shape of a given other array.
diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py
index deb54fc3f..ac3700616 100644
--- a/bayesflow/adapters/transforms/concatenate.py
+++ b/bayesflow/adapters/transforms/concatenate.py
@@ -7,7 +7,7 @@
from .transform import Transform
-@serializable
+@serializable("bayesflow.adapters")
class Concatenate(Transform):
"""Concatenate multiple arrays into a new key. Used to specify how data variables should be treated by the network.
@@ -115,3 +115,37 @@ def extra_repr(self) -> str:
result += f", axis={self.axis}"
return result
+
+ def log_det_jac(
+ self,
+ data: dict[str, np.ndarray],
+ log_det_jac: dict[str, np.ndarray],
+ *,
+ strict: bool = False,
+ inverse: bool = False,
+ **kwargs,
+ ) -> dict[str, np.ndarray]:
+ # copy to avoid side effects
+ log_det_jac = log_det_jac.copy()
+
+ if inverse:
+ if log_det_jac.get(self.into) is not None:
+ raise ValueError(
+ "Cannot obtain an inverse Jacobian of concatenation. "
+ "Transform your variables before you concatenate."
+ )
+
+ return log_det_jac
+
+ required_keys = set(self.keys)
+ available_keys = set(log_det_jac.keys())
+ common_keys = available_keys & required_keys
+
+ if len(common_keys) == 0:
+ return log_det_jac
+
+ parts = [log_det_jac.pop(key) for key in common_keys]
+
+ log_det_jac[self.into] = sum(parts)
+
+ return log_det_jac
diff --git a/bayesflow/adapters/transforms/constrain.py b/bayesflow/adapters/transforms/constrain.py
index 5f93135a1..d01211dfc 100644
--- a/bayesflow/adapters/transforms/constrain.py
+++ b/bayesflow/adapters/transforms/constrain.py
@@ -11,7 +11,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class Constrain(ElementwiseTransform):
"""
Constrains neural network predictions of a data variable to specified bounds.
@@ -87,6 +87,11 @@ def constrain(x):
def unconstrain(x):
return inverse_sigmoid((x - lower) / (upper - lower))
+
+ def ldj(x):
+ x = (x - lower) / (upper - lower)
+ return -np.log(x) - np.log1p(-x) - np.log(upper - lower)
+
case str() as name:
raise ValueError(f"Unsupported method name for double bounded constraint: '{name}'.")
case other:
@@ -101,6 +106,11 @@ def constrain(x):
def unconstrain(x):
return inverse_softplus(x - lower)
+
+ def ldj(x):
+ x = x - lower
+ return x - np.log(np.exp(x) - 1)
+
case "exp" | "log":
def constrain(x):
@@ -108,6 +118,10 @@ def constrain(x):
def unconstrain(x):
return np.log(x - lower)
+
+ def ldj(x):
+ return -np.log(x - lower)
+
case str() as name:
raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
case other:
@@ -122,6 +136,11 @@ def constrain(x):
def unconstrain(x):
return -inverse_softplus(-(x - upper))
+
+ def ldj(x):
+ x = -(x - upper)
+ return x - np.log(np.exp(x) - 1)
+
case "exp" | "log":
def constrain(x):
@@ -129,6 +148,9 @@ def constrain(x):
def unconstrain(x):
return -np.log(-x + upper)
+
+ def ldj(x):
+ return -np.log(-x + upper)
case str() as name:
raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
case other:
@@ -142,6 +164,7 @@ def unconstrain(x):
self.constrain = constrain
self.unconstrain = unconstrain
+ self.ldj = ldj
# do this last to avoid serialization issues
match inclusive:
@@ -178,3 +201,9 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
# inverse means network space -> data space, so constrain the data
return self.constrain(data)
+
+ def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
+ ldj = self.ldj(data)
+ if inverse:
+ ldj = -ldj
+ return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))
diff --git a/bayesflow/adapters/transforms/convert_dtype.py b/bayesflow/adapters/transforms/convert_dtype.py
index e68815269..8cd21b4cc 100644
--- a/bayesflow/adapters/transforms/convert_dtype.py
+++ b/bayesflow/adapters/transforms/convert_dtype.py
@@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class ConvertDType(ElementwiseTransform):
"""
Default transform used to convert all floats from float64 to float32 to be in line with keras framework.
diff --git a/bayesflow/adapters/transforms/drop.py b/bayesflow/adapters/transforms/drop.py
index 51615d632..5073027e6 100644
--- a/bayesflow/adapters/transforms/drop.py
+++ b/bayesflow/adapters/transforms/drop.py
@@ -5,7 +5,7 @@
from .transform import Transform
-@serializable
+@serializable("bayesflow.adapters")
class Drop(Transform):
"""
Transform to drop variables from further calculation.
@@ -46,3 +46,6 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
def extra_repr(self) -> str:
return "[" + ", ".join(map(repr, self.keys)) + "]"
+
+ def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
+ return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac)
diff --git a/bayesflow/adapters/transforms/elementwise_transform.py b/bayesflow/adapters/transforms/elementwise_transform.py
index 3bde5a1da..020301749 100644
--- a/bayesflow/adapters/transforms/elementwise_transform.py
+++ b/bayesflow/adapters/transforms/elementwise_transform.py
@@ -3,7 +3,7 @@
from bayesflow.utils.serialization import serializable, deserialize
-@serializable
+@serializable("bayesflow.adapters")
class ElementwiseTransform:
"""Base class on which other transforms are based"""
@@ -25,3 +25,6 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
raise NotImplementedError
+
+ def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray | None:
+ return None
diff --git a/bayesflow/adapters/transforms/expand_dims.py b/bayesflow/adapters/transforms/expand_dims.py
index e44d133b8..0f4151d37 100644
--- a/bayesflow/adapters/transforms/expand_dims.py
+++ b/bayesflow/adapters/transforms/expand_dims.py
@@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class ExpandDims(ElementwiseTransform):
"""
Expand the shape of an array.
diff --git a/bayesflow/adapters/transforms/filter_transform.py b/bayesflow/adapters/transforms/filter_transform.py
index e1920e73c..4dc2c8008 100644
--- a/bayesflow/adapters/transforms/filter_transform.py
+++ b/bayesflow/adapters/transforms/filter_transform.py
@@ -14,7 +14,7 @@ def __call__(self, key: str, value: np.ndarray, inverse: bool) -> bool:
raise NotImplementedError
-@serializable
+@serializable("bayesflow.adapters")
class FilterTransform(Transform):
"""
Implements a transform that applies a different transform on a subset of the data.
@@ -150,9 +150,35 @@ def _should_transform(self, key: str, value: np.ndarray, inverse: bool = False)
return predicate(key, value, inverse=inverse)
def _apply_transform(self, key: str, value: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
+ transform = self._get_transform(key)
+
+ return transform(value, inverse=inverse, **kwargs)
+
+ def _get_transform(self, key: str) -> ElementwiseTransform:
if key not in self.transform_map:
self.transform_map[key] = self.transform_constructor(**self.kwargs)
- transform = self.transform_map[key]
+ return self.transform_map[key]
- return transform(value, inverse=inverse, **kwargs)
+ def log_det_jac(
+ self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs
+ ):
+ data = data.copy()
+
+ if strict and self.include is not None:
+ missing_keys = set(self.include) - set(data.keys())
+ if missing_keys:
+ raise KeyError(f"Missing keys from include list: {missing_keys!r}")
+
+ for key, value in data.items():
+ if self._should_transform(key, value, inverse=False):
+ transform = self._get_transform(key)
+ ldj = transform.log_det_jac(value, **kwargs)
+ if ldj is None:
+ continue
+ elif key in log_det_jac:
+ log_det_jac[key] += ldj
+ else:
+ log_det_jac[key] = ldj
+
+ return log_det_jac
diff --git a/bayesflow/adapters/transforms/keep.py b/bayesflow/adapters/transforms/keep.py
index 62373071f..c69d01ca3 100644
--- a/bayesflow/adapters/transforms/keep.py
+++ b/bayesflow/adapters/transforms/keep.py
@@ -5,7 +5,7 @@
from .transform import Transform
-@serializable
+@serializable("bayesflow.adapters")
class Keep(Transform):
"""
Name the data parameters that should be kept for futher calculation.
@@ -57,3 +57,6 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
def extra_repr(self) -> str:
return "[" + ", ".join(map(repr, self.keys)) + "]"
+
+ def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
+ return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac)
diff --git a/bayesflow/adapters/transforms/log.py b/bayesflow/adapters/transforms/log.py
index 3184ab979..a42c43ef0 100644
--- a/bayesflow/adapters/transforms/log.py
+++ b/bayesflow/adapters/transforms/log.py
@@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class Log(ElementwiseTransform):
"""Log transforms a variable.
@@ -37,3 +37,12 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
def get_config(self) -> dict:
return serialize({"p1": self.p1})
+
+ def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
+ if self.p1:
+ ldj = -np.log1p(data)
+ else:
+ ldj = -np.log(data)
+ if inverse:
+ ldj = -ldj
+ return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))
diff --git a/bayesflow/adapters/transforms/map_transform.py b/bayesflow/adapters/transforms/map_transform.py
index 7820ce611..15c5c945d 100644
--- a/bayesflow/adapters/transforms/map_transform.py
+++ b/bayesflow/adapters/transforms/map_transform.py
@@ -6,7 +6,7 @@
from .transform import Transform
-@serializable
+@serializable("bayesflow.adapters")
class MapTransform(Transform):
"""
Implements a transform that applies a set of elementwise transforms
@@ -41,12 +41,8 @@ def get_config(self) -> dict:
def forward(self, data: dict[str, np.ndarray], *, strict: bool = True, **kwargs) -> dict[str, np.ndarray]:
data = data.copy()
- required_keys = set(self.transform_map.keys())
- available_keys = set(data.keys())
- missing_keys = required_keys - available_keys
-
- if strict and missing_keys:
- raise KeyError(f"Missing keys: {missing_keys!r}")
+ if strict:
+ self._check_keys(data)
for key, transform in self.transform_map.items():
if key in data:
@@ -57,15 +53,40 @@ def forward(self, data: dict[str, np.ndarray], *, strict: bool = True, **kwargs)
def inverse(self, data: dict[str, np.ndarray], *, strict: bool = False, **kwargs) -> dict[str, np.ndarray]:
data = data.copy()
- required_keys = set(self.transform_map.keys())
- available_keys = set(data.keys())
- missing_keys = required_keys - available_keys
-
- if strict and missing_keys:
- raise KeyError(f"Missing keys: {missing_keys!r}")
+ if strict:
+ self._check_keys(data)
for key, transform in self.transform_map.items():
if key in data:
data[key] = transform.inverse(data[key], **kwargs)
return data
+
+ def log_det_jac(
+ self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs
+ ) -> dict[str, np.ndarray]:
+ data = data.copy()
+
+ if strict:
+ self._check_keys(data)
+
+ for key, transform in self.transform_map.items():
+ if key in data:
+ ldj = transform.log_det_jac(data[key], **kwargs)
+
+ if ldj is None:
+ continue
+ elif key in log_det_jac:
+ log_det_jac[key] += ldj
+ else:
+ log_det_jac[key] = ldj
+
+ return log_det_jac
+
+ def _check_keys(self, data: dict[str, np.ndarray]):
+ required_keys = set(self.transform_map.keys())
+ available_keys = set(data.keys())
+ missing_keys = required_keys - available_keys
+
+ if missing_keys:
+ raise KeyError(f"Missing keys: {missing_keys!r}")
diff --git a/bayesflow/adapters/transforms/numpy_transform.py b/bayesflow/adapters/transforms/numpy_transform.py
index aecf03bba..a19216dd2 100644
--- a/bayesflow/adapters/transforms/numpy_transform.py
+++ b/bayesflow/adapters/transforms/numpy_transform.py
@@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class NumpyTransform(ElementwiseTransform):
"""
A class to apply element-wise transformations using plain NumPy functions.
@@ -72,3 +72,6 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return self._inverse(data)
+
+ def log_det_jac(self, data, inverse=False, **kwargs):
+ raise NotImplementedError("log determinant of the Jacobian of the numpy transforms are not implemented yet")
diff --git a/bayesflow/adapters/transforms/one_hot.py b/bayesflow/adapters/transforms/one_hot.py
index e097a28f9..bbed8bf5d 100644
--- a/bayesflow/adapters/transforms/one_hot.py
+++ b/bayesflow/adapters/transforms/one_hot.py
@@ -6,7 +6,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class OneHot(ElementwiseTransform):
"""
Changes data to be one-hot encoded.
diff --git a/bayesflow/adapters/transforms/rename.py b/bayesflow/adapters/transforms/rename.py
index 49cc52eba..bec3388b0 100644
--- a/bayesflow/adapters/transforms/rename.py
+++ b/bayesflow/adapters/transforms/rename.py
@@ -3,7 +3,7 @@
from .transform import Transform
-@serializable
+@serializable("bayesflow.adapters")
class Rename(Transform):
"""
Transform to rename keys in data dictionary. Useful to rename variables to match those required by
@@ -58,3 +58,6 @@ def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> di
def extra_repr(self) -> str:
return f"{self.from_key!r} -> {self.to_key!r}"
+
+ def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
+ return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac, strict=False)
diff --git a/bayesflow/adapters/transforms/scale.py b/bayesflow/adapters/transforms/scale.py
index 8d9dce1be..d7c1aa2a7 100644
--- a/bayesflow/adapters/transforms/scale.py
+++ b/bayesflow/adapters/transforms/scale.py
@@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class Scale(ElementwiseTransform):
def __init__(self, scale: np.typing.ArrayLike):
self.scale = np.array(scale)
@@ -18,3 +18,10 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return data / self.scale
+
+ def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
+ ldj = np.log(np.abs(self.scale))
+ ldj = np.full(data.shape, ldj)
+ if inverse:
+ ldj = -ldj
+ return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))
diff --git a/bayesflow/adapters/transforms/serializable_custom_transform.py b/bayesflow/adapters/transforms/serializable_custom_transform.py
index 75d588afd..248fc0ec5 100644
--- a/bayesflow/adapters/transforms/serializable_custom_transform.py
+++ b/bayesflow/adapters/transforms/serializable_custom_transform.py
@@ -1,18 +1,17 @@
from collections.abc import Callable
import numpy as np
from keras.saving import (
- deserialize_keras_object as deserialize,
- register_keras_serializable as serializable,
- serialize_keras_object as serialize,
get_registered_name,
get_registered_object,
)
+
+from bayesflow.utils.serialization import deserialize, serializable, serialize
from .elementwise_transform import ElementwiseTransform
from ...utils import filter_kwargs
import inspect
-@serializable(package="bayesflow.adapters")
+@serializable("bayesflow.adapters")
class SerializableCustomTransform(ElementwiseTransform):
"""
Transforms a parameter using a pair of registered serializable forward and inverse functions.
diff --git a/bayesflow/adapters/transforms/shift.py b/bayesflow/adapters/transforms/shift.py
index b7c9659d2..5923b4e49 100644
--- a/bayesflow/adapters/transforms/shift.py
+++ b/bayesflow/adapters/transforms/shift.py
@@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class Shift(ElementwiseTransform):
def __init__(self, shift: np.typing.ArrayLike):
self.shift = np.array(shift)
diff --git a/bayesflow/adapters/transforms/split.py b/bayesflow/adapters/transforms/split.py
index 919db4e08..4c0ae9f65 100644
--- a/bayesflow/adapters/transforms/split.py
+++ b/bayesflow/adapters/transforms/split.py
@@ -6,7 +6,7 @@
from .transform import Transform
-@serializable
+@serializable("bayesflow.adapters")
class Split(Transform):
"""This is the effective inverse of the :py:class:`~Concatenate` Transform.
diff --git a/bayesflow/adapters/transforms/sqrt.py b/bayesflow/adapters/transforms/sqrt.py
index 617f892bc..bcfe49136 100644
--- a/bayesflow/adapters/transforms/sqrt.py
+++ b/bayesflow/adapters/transforms/sqrt.py
@@ -5,7 +5,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class Sqrt(ElementwiseTransform):
"""Square-root transform a variable.
@@ -22,3 +22,9 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
def get_config(self) -> dict:
return {}
+
+ def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
+ ldj = -0.5 * np.log(data) - np.log(2)
+ if inverse:
+ ldj = -ldj
+ return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))
diff --git a/bayesflow/adapters/transforms/standardize.py b/bayesflow/adapters/transforms/standardize.py
index 52899917c..a1c3c5a3d 100644
--- a/bayesflow/adapters/transforms/standardize.py
+++ b/bayesflow/adapters/transforms/standardize.py
@@ -7,7 +7,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class Standardize(ElementwiseTransform):
"""
Transform that when applied standardizes data using typical z-score standardization
@@ -120,3 +120,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
std = np.broadcast_to(self.std, data.shape)
return data * std + mean
+
+ def log_det_jac(self, data, inverse: bool = False, **kwargs) -> np.ndarray:
+ std = np.broadcast_to(self.std, data.shape)
+ ldj = np.log(np.abs(std))
+ if inverse:
+ ldj = -ldj
+ return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))
diff --git a/bayesflow/adapters/transforms/to_array.py b/bayesflow/adapters/transforms/to_array.py
index 9d5381ca0..fe1b82f2d 100644
--- a/bayesflow/adapters/transforms/to_array.py
+++ b/bayesflow/adapters/transforms/to_array.py
@@ -7,7 +7,7 @@
from .elementwise_transform import ElementwiseTransform
-@serializable
+@serializable("bayesflow.adapters")
class ToArray(ElementwiseTransform):
"""
Checks provided data for any non-arrays and converts them to numpy arrays.
diff --git a/bayesflow/adapters/transforms/to_dict.py b/bayesflow/adapters/transforms/to_dict.py
index 6babb2a40..cfc4ec00d 100644
--- a/bayesflow/adapters/transforms/to_dict.py
+++ b/bayesflow/adapters/transforms/to_dict.py
@@ -6,7 +6,7 @@
from .transform import Transform
-@serializable
+@serializable("bayesflow.adapters")
class ToDict(Transform):
"""Convert non-dict batches (e.g., pandas.DataFrame) to dict batches"""
diff --git a/bayesflow/adapters/transforms/transform.py b/bayesflow/adapters/transforms/transform.py
index 4642c1165..0bc6331bc 100644
--- a/bayesflow/adapters/transforms/transform.py
+++ b/bayesflow/adapters/transforms/transform.py
@@ -3,7 +3,7 @@
from bayesflow.utils.serialization import serializable, deserialize
-@serializable
+@serializable("bayesflow.adapters")
class Transform:
"""
Base class on which other transforms are based
@@ -35,3 +35,8 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray
def extra_repr(self) -> str:
return ""
+
+ def log_det_jac(
+ self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], inverse: bool = False, **kwargs
+ ) -> dict[str, np.ndarray]:
+ return log_det_jac
diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py
index dcb661ca0..3e43a8917 100644
--- a/bayesflow/approximators/continuous_approximator.py
+++ b/bayesflow/approximators/continuous_approximator.py
@@ -13,7 +13,7 @@
from .approximator import Approximator
-@serializable
+@serializable("bayesflow.approximators")
class ContinuousApproximator(Approximator):
"""
Defines a workflow for performing fast posterior or likelihood inference.
@@ -400,6 +400,39 @@ def _sample(
**filter_kwargs(kwargs, self.inference_network.sample),
)
+ def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
+ """
+ Computes the summaries of given data.
+
+ The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.
+
+ Parameters
+ ----------
+ data : Mapping[str, np.ndarray]
+ Dictionary of data as NumPy arrays.
+ **kwargs : dict
+ Additional keyword arguments for the adapter and the summary network.
+
+ Returns
+ -------
+ summaries : np.ndarray
+ Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
+
+ Raises
+ ------
+ ValueError
+ If the approximator does not have a summary network, or the adapter does not produce the output required
+ by the summary network.
+ """
+ if self.summary_network is None:
+ raise ValueError("A summary network is required to compute summeries.")
+ data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
+ if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
+ raise ValueError("Summary variables are required to compute summaries.")
+ summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])
+ summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
+ return summaries
+
def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
"""
Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the
@@ -417,11 +450,16 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dic
np.ndarray
Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
"""
- data = self.adapter(data, strict=False, stage="inference", **kwargs)
+ data, log_det_jac = self.adapter(data, strict=False, stage="inference", log_det_jac=True, **kwargs)
data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
log_prob = self._log_prob(**data, **kwargs)
log_prob = keras.tree.map_structure(keras.ops.convert_to_numpy, log_prob)
+ # change of variables formula
+ log_det_jac = log_det_jac.get("inference_variables")
+ if log_det_jac is not None:
+ log_prob = log_prob + log_det_jac
+
return log_prob
def _log_prob(
diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py
index 1b9d198ff..86af4ee33 100644
--- a/bayesflow/approximators/model_comparison_approximator.py
+++ b/bayesflow/approximators/model_comparison_approximator.py
@@ -14,7 +14,7 @@
from .approximator import Approximator
-@serializable
+@serializable("bayesflow.approximators")
class ModelComparisonApproximator(Approximator):
"""
Defines an approximator for model (simulator) comparison, where the (discrete) posterior model probabilities are
@@ -345,3 +345,36 @@ def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tens
output = self.logits_projector(output)
return output
+
+ def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
+ """
+ Computes the summaries of given data.
+
+ The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.
+
+ Parameters
+ ----------
+ data : Mapping[str, np.ndarray]
+ Dictionary of data as NumPy arrays.
+ **kwargs : dict
+ Additional keyword arguments for the adapter and the summary network.
+
+ Returns
+ -------
+ summaries : np.ndarray
+ Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
+
+ Raises
+ ------
+ ValueError
+ If the approximator does not have a summary network, or the adapter does not produce the output required
+ by the summary network.
+ """
+ if self.summary_network is None:
+ raise ValueError("A summary network is required to compute summaries.")
+ data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
+ if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
+ raise ValueError("Summary variables are required to compute summaries.")
+ summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])
+ summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
+ return summaries
diff --git a/bayesflow/approximators/point_approximator.py b/bayesflow/approximators/point_approximator.py
index 1e407e2a6..b3d90781c 100644
--- a/bayesflow/approximators/point_approximator.py
+++ b/bayesflow/approximators/point_approximator.py
@@ -11,7 +11,7 @@
from .continuous_approximator import ContinuousApproximator
-@serializable
+@serializable("bayesflow.approximators")
class PointApproximator(ContinuousApproximator):
"""
A workflow for fast amortized point estimation of a conditional distribution.
diff --git a/bayesflow/diagnostics/__init__.py b/bayesflow/diagnostics/__init__.py
index 1e13e11f2..87823c754 100644
--- a/bayesflow/diagnostics/__init__.py
+++ b/bayesflow/diagnostics/__init__.py
@@ -1,8 +1,13 @@
-"""
+r"""
A collection of plotting utilities and metrics for evaluating trained :py:class:`~bayesflow.workflows.Workflow`\ s.
"""
-from .metrics import root_mean_squared_error, calibration_error, posterior_contraction
+from .metrics import (
+ bootstrap_comparison,
+ calibration_error,
+ posterior_contraction,
+ summary_space_comparison,
+)
from .plots import (
calibration_ecdf,
diff --git a/bayesflow/diagnostics/metrics/__init__.py b/bayesflow/diagnostics/metrics/__init__.py
index ceeca4cc4..3e3496cda 100644
--- a/bayesflow/diagnostics/metrics/__init__.py
+++ b/bayesflow/diagnostics/metrics/__init__.py
@@ -3,3 +3,4 @@
from .root_mean_squared_error import root_mean_squared_error
from .expected_calibration_error import expected_calibration_error
from .classifier_two_sample_test import classifier_two_sample_test
+from .model_misspecification import bootstrap_comparison, summary_space_comparison
diff --git a/bayesflow/diagnostics/metrics/model_misspecification.py b/bayesflow/diagnostics/metrics/model_misspecification.py
new file mode 100644
index 000000000..c698d4eb2
--- /dev/null
+++ b/bayesflow/diagnostics/metrics/model_misspecification.py
@@ -0,0 +1,155 @@
+"""
+This module provides functions for computing distances between observation samples and reference samples with distance
+distributions within the reference samples for hypothesis testing.
+"""
+
+from collections.abc import Mapping, Callable
+
+import numpy as np
+from keras.ops import convert_to_numpy, convert_to_tensor
+
+from bayesflow.approximators import ContinuousApproximator
+from bayesflow.metrics.functional import maximum_mean_discrepancy
+from bayesflow.types import Tensor
+
+
+def bootstrap_comparison(
+ observed_samples: np.ndarray,
+ reference_samples: np.ndarray,
+ comparison_fn: Callable[[Tensor, Tensor], Tensor],
+ num_null_samples: int = 100,
+) -> tuple[float, np.ndarray]:
+ """Computes the distance between observed and reference samples and generates a distribution of null sample
+ distances by bootstrapping for hypothesis testing.
+
+ Parameters
+ ----------
+ observed_samples : np.ndarray)
+ Observed samples, shape (num_observed, ...).
+ reference_samples : np.ndarray
+ Reference samples, shape (num_reference, ...).
+ comparison_fn : Callable[[Tensor, Tensor], Tensor]
+ Function to compute the distance metric.
+ num_null_samples : int
+ Number of null samples to generate for hypothesis testing. Default is 100.
+
+ Returns
+ -------
+ distance_observed : float
+ The distance value between observed and reference samples.
+ distance_null : np.ndarray
+ A distribution of distance values under the null hypothesis.
+
+ Raises
+ ------
+ ValueError
+ - If the number of number of observed samples exceeds the number of reference samples
+ - If the shapes of observed and reference samples do not match on dimensions besides the first one.
+ """
+ num_observed: int = observed_samples.shape[0]
+ num_reference: int = reference_samples.shape[0]
+
+ if num_observed > num_reference:
+ raise ValueError(
+ f"Number of observed samples ({num_observed}) cannot exceed"
+ f"the number of reference samples ({num_reference}) for bootstrapping."
+ )
+ if observed_samples.shape[1:] != reference_samples.shape[1:]:
+ raise ValueError(
+ f"Expected observed and reference samples to have the same shape, "
+ f"but got {observed_samples.shape[1:]} != {reference_samples.shape[1:]}."
+ )
+
+ observed_samples_tensor: Tensor = convert_to_tensor(observed_samples, dtype="float32")
+ reference_samples_tensor: Tensor = convert_to_tensor(reference_samples, dtype="float32")
+
+ distance_null_samples: np.ndarray = np.zeros(num_null_samples, dtype=np.float64)
+ for i in range(num_null_samples):
+ bootstrap_idx: np.ndarray = np.random.randint(0, num_reference, size=num_observed)
+ bootstrap_samples: np.ndarray = reference_samples[bootstrap_idx]
+ bootstrap_samples_tensor: Tensor = convert_to_tensor(bootstrap_samples, dtype="float32")
+ distance_null_samples[i] = convert_to_numpy(comparison_fn(bootstrap_samples_tensor, reference_samples_tensor))
+
+ distance_observed_tensor: Tensor = comparison_fn(
+ observed_samples_tensor,
+ reference_samples_tensor,
+ )
+
+ distance_observed: float = float(convert_to_numpy(distance_observed_tensor))
+
+ return distance_observed, distance_null_samples
+
+
+def summary_space_comparison(
+ observed_data: Mapping[str, np.ndarray],
+ reference_data: Mapping[str, np.ndarray],
+ approximator: ContinuousApproximator,
+ num_null_samples: int = 100,
+ comparison_fn: Callable = maximum_mean_discrepancy,
+ **kwargs,
+) -> tuple[float, np.ndarray]:
+ """Computes the distance between observed and reference data in the summary space and
+ generates a distribution of distance values under the null hypothesis to assess model misspecification.
+
+ By default, the Maximum Mean Discrepancy (MMD) is used as a distance function.
+
+ [1] M. Schmitt, P.-C. Bürkner, U. Köthe, and S. T. Radev, "Detecting model misspecification in amortized Bayesian
+ inference with neural networks," arXiv e-prints, Dec. 2021, Art. no. arXiv:2112.08866.
+ URL: https://arxiv.org/abs/2112.08866
+
+ Parameters
+ ----------
+ observed_data : dict[str, np.ndarray]
+ Dictionary of observed data as NumPy arrays, which will be preprocessed by the approximators adapter and passed
+ through its summary network.
+ reference_data : dict[str, np.ndarray]
+ Dictionary of reference data as NumPy arrays, which will be preprocessed by the approximators adapter and passed
+ through its summary network.
+ approximator : ContinuousApproximator
+ An instance of :py:class:`~bayesflow.approximators.ContinuousApproximator` used to compute summary statistics
+ from the data.
+ num_null_samples : int, optional
+ Number of null samples to generate for hypothesis testing. Default is 100.
+ comparison_fn : Callable, optional
+ Distance function to compare the data in the summary space.
+ **kwargs : dict
+ Additional keyword arguments for the adapter and sampling process.
+
+ Returns
+ -------
+ distance_observed : float
+ The MMD value between observed and reference summaries.
+ distance_null : np.ndarray
+ A distribution of MMD values under the null hypothesis.
+
+ Raises
+ ------
+ ValueError
+ If approximator is not an instance of ContinuousApproximator or does not have a summary network.
+ """
+
+ if not isinstance(approximator, ContinuousApproximator):
+ raise ValueError("The approximator must be an instance of ContinuousApproximator.")
+
+ if not hasattr(approximator, "summary_network") or approximator.summary_network is None:
+ comparison_fn_name = (
+ "bayesflow.metrics.functional.maximum_mean_discrepancy"
+ if comparison_fn is maximum_mean_discrepancy
+ else comparison_fn.__name__
+ )
+ raise ValueError(
+ "The approximator must have a summary network. If you have manually crafted summary "
+ "statistics, or want to compare raw data and not summary statistics, please use the "
+ f"`bootstrap_comparison` function with `comparison_fn={comparison_fn_name}` on the respective arrays."
+ )
+ observed_summaries = convert_to_numpy(approximator.summaries(observed_data))
+ reference_summaries = convert_to_numpy(approximator.summaries(reference_data))
+
+ distance_observed, distance_null = bootstrap_comparison(
+ observed_samples=observed_summaries,
+ reference_samples=reference_summaries,
+ comparison_fn=comparison_fn,
+ num_null_samples=num_null_samples,
+ )
+
+ return distance_observed, distance_null
diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py
index 98a127b1c..f8d93b945 100644
--- a/bayesflow/distributions/diagonal_normal.py
+++ b/bayesflow/distributions/diagonal_normal.py
@@ -12,7 +12,7 @@
from .distribution import Distribution
-@serializable
+@serializable("bayesflow.distributions")
class DiagonalNormal(Distribution):
"""Implements a backend-agnostic diagonal Gaussian distribution."""
diff --git a/bayesflow/distributions/diagonal_student_t.py b/bayesflow/distributions/diagonal_student_t.py
index cd32a67fb..98e3fb7eb 100644
--- a/bayesflow/distributions/diagonal_student_t.py
+++ b/bayesflow/distributions/diagonal_student_t.py
@@ -13,7 +13,7 @@
from .distribution import Distribution
-@serializable
+@serializable("bayesflow.distributions")
class DiagonalStudentT(Distribution):
"""Implements a backend-agnostic diagonal Student-t distribution."""
diff --git a/bayesflow/distributions/distribution.py b/bayesflow/distributions/distribution.py
index 1d3a83962..3689f0d9f 100644
--- a/bayesflow/distributions/distribution.py
+++ b/bayesflow/distributions/distribution.py
@@ -5,7 +5,7 @@
from bayesflow.utils.serialization import serializable, deserialize
-@serializable
+@serializable("bayesflow.distributions")
class Distribution(keras.Layer):
def __init__(self, **kwargs):
super().__init__(**layer_kwargs(kwargs))
diff --git a/bayesflow/distributions/mixture.py b/bayesflow/distributions/mixture.py
index d7f6bd758..a7bf2ea27 100644
--- a/bayesflow/distributions/mixture.py
+++ b/bayesflow/distributions/mixture.py
@@ -11,7 +11,7 @@
from bayesflow.distributions import Distribution
-@serializable
+@serializable("bayesflow.distributions")
class Mixture(Distribution):
"""Utility class for a backend-agnostic mixture distributions."""
diff --git a/bayesflow/experimental/cif/cif.py b/bayesflow/experimental/cif/cif.py
index 8742501b3..bd776f93e 100644
--- a/bayesflow/experimental/cif/cif.py
+++ b/bayesflow/experimental/cif/cif.py
@@ -1,7 +1,7 @@
import keras
-from keras.saving import register_keras_serializable as serializable
from bayesflow.types import Shape, Tensor
+from bayesflow.utils.serialization import serializable
from bayesflow.networks.inference_network import InferenceNetwork
from bayesflow.networks.coupling_flow import CouplingFlow
@@ -9,7 +9,8 @@
from .conditional_gaussian import ConditionalGaussian
-@serializable(package="bayesflow.networks")
+# disable module check, use potential module after moving from experimental
+@serializable("bayesflow.networks", disable_module_check=True)
class CIF(InferenceNetwork):
"""Implements a continuously indexed flow (CIF) with a `CouplingFlow`
bijection and `ConditionalGaussian` distributions p and q. Improves on
diff --git a/bayesflow/experimental/cif/conditional_gaussian.py b/bayesflow/experimental/cif/conditional_gaussian.py
index d11f3fb65..ebba47a2e 100644
--- a/bayesflow/experimental/cif/conditional_gaussian.py
+++ b/bayesflow/experimental/cif/conditional_gaussian.py
@@ -1,13 +1,14 @@
import keras
-from keras.saving import register_keras_serializable
import numpy as np
from bayesflow.networks.mlp import MLP
from bayesflow.types import Shape, Tensor
from bayesflow.utils import layer_kwargs
+from bayesflow.utils.serialization import serializable
-@register_keras_serializable(package="bayesflow.networks.cif")
+# disable module check, use potential module after moving from experimental
+@serializable("bayesflow.networks", disable_module_check=True)
class ConditionalGaussian(keras.Layer):
"""Implements a conditional gaussian distribution with neural networks for
the means and standard deviations respectively. Bulit in reference to [1].
diff --git a/bayesflow/experimental/continuous_time_consistency_model.py b/bayesflow/experimental/continuous_time_consistency_model.py
index 54417cd07..b1c751454 100644
--- a/bayesflow/experimental/continuous_time_consistency_model.py
+++ b/bayesflow/experimental/continuous_time_consistency_model.py
@@ -22,7 +22,8 @@
from bayesflow.networks.embeddings import FourierEmbedding
-@serializable
+# disable module check, use potential module after moving from experimental
+@serializable("bayesflow.networks", disable_module_check=True)
class ContinuousTimeConsistencyModel(InferenceNetwork):
"""Implements an sCM (simple, stable, and scalable Consistency Model)
with continous-time Consistency Training (CT) as described in [1].
diff --git a/bayesflow/experimental/free_form_flow/free_form_flow.py b/bayesflow/experimental/free_form_flow/free_form_flow.py
index 61937d56f..12bb97b93 100644
--- a/bayesflow/experimental/free_form_flow/free_form_flow.py
+++ b/bayesflow/experimental/free_form_flow/free_form_flow.py
@@ -19,7 +19,8 @@
from bayesflow.networks import InferenceNetwork
-@serializable
+# disable module check, use potential module after moving from experimental
+@serializable("bayesflow.networks", disable_module_check=True)
class FreeFormFlow(InferenceNetwork):
"""Implements a dimensionality-preserving Free-form Flow.
Incorporates ideas from [1-2].
diff --git a/bayesflow/experimental/resnet/dense_resnet.py b/bayesflow/experimental/resnet/dense_resnet.py
index fa380969f..93ff59d3f 100644
--- a/bayesflow/experimental/resnet/dense_resnet.py
+++ b/bayesflow/experimental/resnet/dense_resnet.py
@@ -8,7 +8,8 @@
from .double_linear import DoubleLinear
-@serializable
+# disable module check, use potential module after moving from experimental
+@serializable("bayesflow.networks", disable_module_check=True)
class DenseResNet(keras.Sequential):
"""
Implements the fully-connected analogue of the ResNet architecture.
diff --git a/bayesflow/experimental/resnet/double_conv.py b/bayesflow/experimental/resnet/double_conv.py
index c70e37323..a2b6bbc88 100644
--- a/bayesflow/experimental/resnet/double_conv.py
+++ b/bayesflow/experimental/resnet/double_conv.py
@@ -5,7 +5,8 @@
from bayesflow.utils.serialization import deserialize, serializable, serialize
-@serializable
+# disable module check, use potential module after moving from experimental
+@serializable("bayesflow.networks", disable_module_check=True)
class DoubleConv(keras.Sequential):
def __init__(
self,
diff --git a/bayesflow/experimental/resnet/double_linear.py b/bayesflow/experimental/resnet/double_linear.py
index e2138c8b0..aae72fa39 100644
--- a/bayesflow/experimental/resnet/double_linear.py
+++ b/bayesflow/experimental/resnet/double_linear.py
@@ -5,7 +5,8 @@
from bayesflow.utils.serialization import deserialize, serializable, serialize
-@serializable
+# disable module check, use potential module after moving from experimental
+@serializable("bayesflow.networks", disable_module_check=True)
class DoubleLinear(keras.Sequential):
def __init__(
self,
diff --git a/bayesflow/experimental/resnet/resnet.py b/bayesflow/experimental/resnet/resnet.py
index 07f1f2cda..862e0ac98 100644
--- a/bayesflow/experimental/resnet/resnet.py
+++ b/bayesflow/experimental/resnet/resnet.py
@@ -8,7 +8,8 @@
from .double_conv import DoubleConv
-@serializable
+# disable module check, use potential module after moving from experimental
+@serializable("bayesflow.networks", disable_module_check=True)
class ResNet(keras.Sequential):
"""
Implements the ResNet architecture.
diff --git a/bayesflow/links/ordered.py b/bayesflow/links/ordered.py
index 77545b6f8..25caf5350 100644
--- a/bayesflow/links/ordered.py
+++ b/bayesflow/links/ordered.py
@@ -1,11 +1,11 @@
import keras
-from keras.saving import register_keras_serializable as serializable
from bayesflow.utils import layer_kwargs
from bayesflow.utils.decorators import sanitize_input_shape
+from bayesflow.utils.serialization import serializable
-@serializable(package="links.ordered")
+@serializable("bayesflow.links")
class Ordered(keras.Layer):
"""Activation function to link to a tensor which is monotonously increasing along a specified axis."""
diff --git a/bayesflow/links/ordered_quantiles.py b/bayesflow/links/ordered_quantiles.py
index d4f4caba2..81b2c0cc7 100644
--- a/bayesflow/links/ordered_quantiles.py
+++ b/bayesflow/links/ordered_quantiles.py
@@ -1,14 +1,14 @@
import keras
-from keras.saving import register_keras_serializable as serializable
from bayesflow.utils import layer_kwargs, logging
+from bayesflow.utils.serialization import serializable
from collections.abc import Sequence
from .ordered import Ordered
-@serializable(package="links.ordered_quantiles")
+@serializable("bayesflow.links")
class OrderedQuantiles(Ordered):
"""Activation function to link to monotonously increasing quantile estimates."""
diff --git a/bayesflow/links/positive_definite.py b/bayesflow/links/positive_definite.py
index 28c937f86..909ac2792 100644
--- a/bayesflow/links/positive_definite.py
+++ b/bayesflow/links/positive_definite.py
@@ -1,12 +1,11 @@
import keras
-from keras.saving import register_keras_serializable as serializable
-
from bayesflow.types import Tensor
from bayesflow.utils import layer_kwargs, fill_triangular_matrix
+from bayesflow.utils.serialization import serializable
-@serializable(package="bayesflow.links")
+@serializable("bayesflow.links")
class PositiveDefinite(keras.Layer):
"""Activation function to link from flat elements of a lower triangular matrix to a positive definite matrix."""
diff --git a/bayesflow/metrics/maximum_mean_discrepancy.py b/bayesflow/metrics/maximum_mean_discrepancy.py
index 37af44fd4..de4ee32f1 100644
--- a/bayesflow/metrics/maximum_mean_discrepancy.py
+++ b/bayesflow/metrics/maximum_mean_discrepancy.py
@@ -6,7 +6,7 @@
from .functional import maximum_mean_discrepancy
-@serializable
+@serializable("bayesflow.metrics")
class MaximumMeanDiscrepancy(keras.Metric):
def __init__(
self,
diff --git a/bayesflow/metrics/root_mean_squard_error.py b/bayesflow/metrics/root_mean_squard_error.py
index 97de62e6a..8827095e9 100644
--- a/bayesflow/metrics/root_mean_squard_error.py
+++ b/bayesflow/metrics/root_mean_squard_error.py
@@ -5,7 +5,7 @@
from .functional import root_mean_squared_error
-@serializable
+@serializable("bayesflow.metrics")
class RootMeanSquaredError(keras.metrics.MeanMetricWrapper):
def __init__(self, name="root_mean_squared_error", dtype=None, **kwargs):
fn = partial(root_mean_squared_error, **kwargs)
diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py
index b8d4c56ed..8d36c1736 100644
--- a/bayesflow/networks/consistency_models/consistency_model.py
+++ b/bayesflow/networks/consistency_models/consistency_model.py
@@ -12,7 +12,7 @@
from ..inference_network import InferenceNetwork
-@serializable
+@serializable("bayesflow.networks")
class ConsistencyModel(InferenceNetwork):
"""Implements a Consistency Model with Consistency Training (CT) a described in [1-2]. The adaptations to CT
described in [2] were taken into account in our implementation for ABI [3].
diff --git a/bayesflow/networks/coupling_flow/actnorm.py b/bayesflow/networks/coupling_flow/actnorm.py
index 5221caea1..81cdc425d 100644
--- a/bayesflow/networks/coupling_flow/actnorm.py
+++ b/bayesflow/networks/coupling_flow/actnorm.py
@@ -6,7 +6,7 @@
from .invertible_layer import InvertibleLayer
-@serializable
+@serializable("bayesflow.networks")
class ActNorm(InvertibleLayer):
"""Implements an Activation Normalization (ActNorm) Layer. Activation Normalization is learned invertible
normalization, using a scale (s) and a bias (b) vector::
diff --git a/bayesflow/networks/coupling_flow/coupling_flow.py b/bayesflow/networks/coupling_flow/coupling_flow.py
index 203962b0f..28954d7d2 100644
--- a/bayesflow/networks/coupling_flow/coupling_flow.py
+++ b/bayesflow/networks/coupling_flow/coupling_flow.py
@@ -13,7 +13,7 @@
from ..inference_network import InferenceNetwork
-@serializable
+@serializable("bayesflow.networks")
class CouplingFlow(InferenceNetwork):
"""Implements a coupling flow as a sequence of dual couplings with permutations and activation
normalization. Incorporates ideas from [1-5].
diff --git a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py
index 67db6e269..462bc02d6 100644
--- a/bayesflow/networks/coupling_flow/couplings/dual_coupling.py
+++ b/bayesflow/networks/coupling_flow/couplings/dual_coupling.py
@@ -9,7 +9,7 @@
from ..invertible_layer import InvertibleLayer
-@serializable
+@serializable("bayesflow.networks")
class DualCoupling(InvertibleLayer):
def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwargs):
super().__init__(**kwargs)
diff --git a/bayesflow/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/networks/coupling_flow/couplings/single_coupling.py
index d2703cfa7..7bd6aaf3b 100644
--- a/bayesflow/networks/coupling_flow/couplings/single_coupling.py
+++ b/bayesflow/networks/coupling_flow/couplings/single_coupling.py
@@ -8,7 +8,7 @@
from ..transforms import find_transform
-@serializable
+@serializable("bayesflow.networks")
class SingleCoupling(InvertibleLayer):
"""
Implements a single coupling layer as a composition of a subnet and a transform.
diff --git a/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py b/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py
index d93f27ae5..68591a172 100644
--- a/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py
+++ b/bayesflow/networks/coupling_flow/permutations/fixed_permutation.py
@@ -6,7 +6,7 @@
from ..invertible_layer import InvertibleLayer
-@serializable
+@serializable("bayesflow.networks")
class FixedPermutation(InvertibleLayer):
"""
Interface class for permutations with no learnable parameters. Child classes should
diff --git a/bayesflow/networks/coupling_flow/permutations/orthogonal.py b/bayesflow/networks/coupling_flow/permutations/orthogonal.py
index a28fe7965..54bfbf901 100644
--- a/bayesflow/networks/coupling_flow/permutations/orthogonal.py
+++ b/bayesflow/networks/coupling_flow/permutations/orthogonal.py
@@ -6,7 +6,7 @@
from ..invertible_layer import InvertibleLayer
-@serializable
+@serializable("bayesflow.networks")
class OrthogonalPermutation(InvertibleLayer):
"""Implements a learnable orthogonal transformation according to [1]. Can be
used as an alternative to a fixed ``Permutation`` layer.
diff --git a/bayesflow/networks/coupling_flow/permutations/random.py b/bayesflow/networks/coupling_flow/permutations/random.py
index 82d7f39ff..522d48c63 100644
--- a/bayesflow/networks/coupling_flow/permutations/random.py
+++ b/bayesflow/networks/coupling_flow/permutations/random.py
@@ -6,7 +6,7 @@
from .fixed_permutation import FixedPermutation
-@serializable
+@serializable("bayesflow.networks")
class RandomPermutation(FixedPermutation):
# noinspection PyMethodOverriding
def build(self, xz_shape: Shape, **kwargs) -> None:
diff --git a/bayesflow/networks/coupling_flow/permutations/swap.py b/bayesflow/networks/coupling_flow/permutations/swap.py
index c5f707a1a..bb7f641b9 100644
--- a/bayesflow/networks/coupling_flow/permutations/swap.py
+++ b/bayesflow/networks/coupling_flow/permutations/swap.py
@@ -6,7 +6,7 @@
from .fixed_permutation import FixedPermutation
-@serializable
+@serializable("bayesflow.networks")
class Swap(FixedPermutation):
def build(self, xz_shape: Shape, **kwargs) -> None:
shift = xz_shape[-1] // 2
diff --git a/bayesflow/networks/coupling_flow/transforms/affine_transform.py b/bayesflow/networks/coupling_flow/transforms/affine_transform.py
index 9e8c4a9e1..1d66b0bfb 100644
--- a/bayesflow/networks/coupling_flow/transforms/affine_transform.py
+++ b/bayesflow/networks/coupling_flow/transforms/affine_transform.py
@@ -7,7 +7,7 @@
from .transform import Transform
-@serializable
+@serializable("bayesflow.networks")
class AffineTransform(Transform):
def __init__(self, clamp: bool = True, **kwargs):
super().__init__(**kwargs)
diff --git a/bayesflow/networks/coupling_flow/transforms/spline_transform.py b/bayesflow/networks/coupling_flow/transforms/spline_transform.py
index d5b0cf4b3..28b9c4415 100644
--- a/bayesflow/networks/coupling_flow/transforms/spline_transform.py
+++ b/bayesflow/networks/coupling_flow/transforms/spline_transform.py
@@ -10,7 +10,7 @@
from .transform import Transform
-@serializable
+@serializable("bayesflow.networks")
class SplineTransform(Transform):
def __init__(
self,
diff --git a/bayesflow/networks/deep_set/deep_set.py b/bayesflow/networks/deep_set/deep_set.py
index 633a1508b..9c9d0ad23 100644
--- a/bayesflow/networks/deep_set/deep_set.py
+++ b/bayesflow/networks/deep_set/deep_set.py
@@ -11,7 +11,7 @@
from ..summary_network import SummaryNetwork
-@serializable
+@serializable("bayesflow.networks")
class DeepSet(SummaryNetwork):
"""Implements a deep set encoder introduced in [1] for learning permutation-invariant representations of
set-based data, as generated by exchangeable models.
@@ -30,7 +30,7 @@ def __init__(
mlp_widths_invariant_inner: Sequence[int] = (64, 64),
mlp_widths_invariant_outer: Sequence[int] = (64, 64),
mlp_widths_invariant_last: Sequence[int] = (64, 64),
- activation: str = "gelu",
+ activation: str = "silu",
kernel_initializer: str = "he_normal",
dropout: int | float | None = 0.05,
spectral_normalization: bool = False,
@@ -72,7 +72,7 @@ def __init__(
mlp_widths_invariant_last : Sequence[int], optional
Widths of the MLP layers in the final invariant transformation. Default is (64, 64).
activation : str, optional
- Activation function used throughout the network, such as "gelu". Default is "gelu".
+ Activation function used throughout the network, such as "gelu". Default is "silu".
kernel_initializer : str, optional
Initialization strategy for kernel weights, such as "he_normal". Default is "he_normal".
dropout : int, float, or None, optional
diff --git a/bayesflow/networks/deep_set/equivariant_layer.py b/bayesflow/networks/deep_set/equivariant_layer.py
index 81bd62f58..7e35ad9bb 100644
--- a/bayesflow/networks/deep_set/equivariant_layer.py
+++ b/bayesflow/networks/deep_set/equivariant_layer.py
@@ -13,7 +13,7 @@
from .invariant_layer import InvariantLayer
-@serializable
+@serializable("bayesflow.networks")
class EquivariantLayer(keras.Layer):
"""Implements an equivariant module performing an equivariant transform.
@@ -94,6 +94,7 @@ def __init__(
kernel_initializer=kernel_initializer,
spectral_normalization=spectral_normalization,
)
+ self.out_fc_projector = keras.layers.Dense(mlp_widths_equivariant[-1], kernel_initializer=kernel_initializer)
self.layer_norm = layers.LayerNormalization() if layer_norm else None
@@ -137,7 +138,10 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
output_set = ops.concatenate([input_set, invariant_summary], axis=-1)
# Pass through final equivariant transform + residual
- output_set = input_set + self.equivariant_fc(output_set, training=training)
+ out_fc = self.equivariant_fc(output_set, training=training)
+ out_projected = self.out_fc_projector(out_fc)
+ output_set = input_set + out_projected
+
if self.layer_norm is not None:
output_set = self.layer_norm(output_set, training=training)
diff --git a/bayesflow/networks/deep_set/invariant_layer.py b/bayesflow/networks/deep_set/invariant_layer.py
index d1b6a26f9..2f29c6b8d 100644
--- a/bayesflow/networks/deep_set/invariant_layer.py
+++ b/bayesflow/networks/deep_set/invariant_layer.py
@@ -11,7 +11,7 @@
from ..mlp import MLP
-@serializable
+@serializable("bayesflow.networks")
class InvariantLayer(keras.Layer):
"""Implements an invariant module performing a permutation-invariant transform.
@@ -74,6 +74,7 @@ def __init__(
kernel_initializer=kernel_initializer,
spectral_normalization=spectral_normalization,
)
+ self.inner_projector = keras.layers.Dense(mlp_widths_inner[-1], kernel_initializer=kernel_initializer)
self.outer_fc = MLP(
mlp_widths_outer,
@@ -82,6 +83,7 @@ def __init__(
kernel_initializer=kernel_initializer,
spectral_normalization=spectral_normalization,
)
+ self.outer_projector = keras.layers.Dense(mlp_widths_outer[-1], kernel_initializer=kernel_initializer)
# Pooling function as keras layer for sum decomposition: inner( pooling( inner(set) ) )
if pooling_kwargs is None:
@@ -106,8 +108,10 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
"""
set_summary = self.inner_fc(input_set, training=training)
+ set_summary = self.inner_projector(set_summary)
set_summary = self.pooling_layer(set_summary, training=training)
set_summary = self.outer_fc(set_summary, training=training)
+ set_summary = self.outer_projector(set_summary)
return set_summary
@sanitize_input_shape
diff --git a/bayesflow/networks/embeddings/fourier_embedding.py b/bayesflow/networks/embeddings/fourier_embedding.py
index 65b5938d7..21924ee60 100644
--- a/bayesflow/networks/embeddings/fourier_embedding.py
+++ b/bayesflow/networks/embeddings/fourier_embedding.py
@@ -7,7 +7,7 @@
from bayesflow.utils.serialization import serializable
-@serializable
+@serializable("bayesflow.networks")
class FourierEmbedding(keras.Layer):
"""Implements a Fourier projection with normally distributed frequencies."""
diff --git a/bayesflow/networks/embeddings/recurrent_embedding.py b/bayesflow/networks/embeddings/recurrent_embedding.py
index df7c00f32..3fa82868d 100644
--- a/bayesflow/networks/embeddings/recurrent_embedding.py
+++ b/bayesflow/networks/embeddings/recurrent_embedding.py
@@ -6,7 +6,7 @@
from bayesflow.utils.serialization import serializable
-@serializable
+@serializable("bayesflow.networks")
class RecurrentEmbedding(keras.Layer):
"""Implements a recurrent network for flexibly embedding time vectors."""
diff --git a/bayesflow/networks/embeddings/time2vec.py b/bayesflow/networks/embeddings/time2vec.py
index 4c9c3a87f..b52ca77d8 100644
--- a/bayesflow/networks/embeddings/time2vec.py
+++ b/bayesflow/networks/embeddings/time2vec.py
@@ -5,7 +5,7 @@
from bayesflow.utils.serialization import serializable
-@serializable
+@serializable("bayesflow.networks")
class Time2Vec(keras.Layer):
"""
Implements the Time2Vec learnbale embedding from [1].
diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py
index 3c0190467..797d4c62d 100644
--- a/bayesflow/networks/flow_matching/flow_matching.py
+++ b/bayesflow/networks/flow_matching/flow_matching.py
@@ -19,7 +19,7 @@
from ..inference_network import InferenceNetwork
-@serializable
+@serializable("bayesflow.networks")
class FlowMatching(InferenceNetwork):
"""Implements Optimal Transport Flow Matching, originally introduced as Rectified Flow, with ideas incorporated
from [1-3].
diff --git a/bayesflow/networks/mlp/mlp.py b/bayesflow/networks/mlp/mlp.py
index 1ac11fe1a..11dcdca2b 100644
--- a/bayesflow/networks/mlp/mlp.py
+++ b/bayesflow/networks/mlp/mlp.py
@@ -9,7 +9,7 @@
from ..residual import Residual
-@serializable
+@serializable("bayesflow.networks")
class MLP(keras.Sequential):
"""
Implements a simple configurable MLP with optional residual connections and dropout.
diff --git a/bayesflow/networks/point_inference_network.py b/bayesflow/networks/point_inference_network.py
index 3b1699e5a..402632355 100644
--- a/bayesflow/networks/point_inference_network.py
+++ b/bayesflow/networks/point_inference_network.py
@@ -1,17 +1,13 @@
import keras
-from keras.saving import (
- deserialize_keras_object as deserialize,
- serialize_keras_object as serialize,
- register_keras_serializable as serializable,
-)
-from bayesflow.utils import model_kwargs, find_network, serialize_value_or_type, deserialize_value_or_type
+from bayesflow.utils import model_kwargs, find_network
+from bayesflow.utils.serialization import deserialize, serializable, serialize
from bayesflow.types import Shape, Tensor
from bayesflow.scores import ScoringRule, ParametricDistributionScore
from bayesflow.utils.decorators import allow_batch_size
-@serializable(package="networks.point_inference_network")
+@serializable("bayesflow.networks")
class PointInferenceNetwork(keras.Layer):
"""Implements point estimation for user specified scoring rules by a shared feed forward architecture
with separate heads for each scoring rule.
@@ -30,10 +26,10 @@ def __init__(
self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
self.config = {
+ "subnet": serialize(subnet),
+ "scores": serialize(scores),
**kwargs,
}
- self.config = serialize_value_or_type(self.config, "subnet", subnet)
- self.config["scores"] = serialize(self.scores)
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
"""Builds all network components based on shapes of conditions and targets.
@@ -119,7 +115,7 @@ def get_config(self):
def from_config(cls, config):
config = config.copy()
config["scores"] = deserialize(config["scores"])
- config = deserialize_value_or_type(config, "subnet")
+ config["subnet"] = deserialize(config["subnet"])
return cls(**config)
def call(
diff --git a/bayesflow/networks/residual/residual.py b/bayesflow/networks/residual/residual.py
index f2ca54b51..edf32782c 100644
--- a/bayesflow/networks/residual/residual.py
+++ b/bayesflow/networks/residual/residual.py
@@ -7,7 +7,7 @@
from bayesflow.utils.serialization import deserialize, serializable, serialize
-@serializable
+@serializable("bayesflow.networks")
class Residual(keras.Sequential):
def __init__(self, *layers: keras.Layer, **kwargs):
if len(layers) == 1 and isinstance(layers[0], Sequence):
diff --git a/bayesflow/networks/time_series_network/skip_recurrent.py b/bayesflow/networks/time_series_network/skip_recurrent.py
index 9b2c06c0d..23dee5156 100644
--- a/bayesflow/networks/time_series_network/skip_recurrent.py
+++ b/bayesflow/networks/time_series_network/skip_recurrent.py
@@ -6,7 +6,7 @@
from bayesflow.utils.serialization import serializable
-@serializable
+@serializable("bayesflow.networks")
class SkipRecurrentNet(keras.Layer):
"""
Implements a Skip recurrent layer as described in [1], allowing a more flexible recurrent backbone
diff --git a/bayesflow/networks/time_series_network/time_series_network.py b/bayesflow/networks/time_series_network/time_series_network.py
index 354806f6c..7a96a099d 100644
--- a/bayesflow/networks/time_series_network/time_series_network.py
+++ b/bayesflow/networks/time_series_network/time_series_network.py
@@ -7,7 +7,7 @@
from ..summary_network import SummaryNetwork
-@serializable
+@serializable("bayesflow.networks")
class TimeSeriesNetwork(SummaryNetwork):
"""
Implements a LSTNet Architecture as described in [1]
diff --git a/bayesflow/networks/transformers/fusion_transformer.py b/bayesflow/networks/transformers/fusion_transformer.py
index 1821c25d2..f416957fb 100644
--- a/bayesflow/networks/transformers/fusion_transformer.py
+++ b/bayesflow/networks/transformers/fusion_transformer.py
@@ -10,7 +10,7 @@
from .mab import MultiHeadAttentionBlock
-@serializable
+@serializable("bayesflow.networks")
class FusionTransformer(SummaryNetwork):
"""Implements a more flexible version of the TimeSeriesTransformer that applies a series of self-attention layers
followed by cross-attention between the representation and a learnable template summarized via a recurrent net."""
diff --git a/bayesflow/networks/transformers/isab.py b/bayesflow/networks/transformers/isab.py
index ae1242469..03f15a561 100644
--- a/bayesflow/networks/transformers/isab.py
+++ b/bayesflow/networks/transformers/isab.py
@@ -7,7 +7,7 @@
from .mab import MultiHeadAttentionBlock
-@serializable
+@serializable("bayesflow.networks")
class InducedSetAttentionBlock(keras.Layer):
"""Implements the ISAB block from [1] which represents learnable self-attention specifically
designed to deal with large sets via a learnable set of "inducing points".
diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py
index 8f0e3f881..5bd7c9dff 100644
--- a/bayesflow/networks/transformers/mab.py
+++ b/bayesflow/networks/transformers/mab.py
@@ -8,7 +8,7 @@
from bayesflow.utils.serialization import serializable
-@serializable
+@serializable("bayesflow.networks")
class MultiHeadAttentionBlock(keras.Layer):
"""Implements the MAB block from [1] which represents learnable cross-attention.
diff --git a/bayesflow/networks/transformers/pma.py b/bayesflow/networks/transformers/pma.py
index 956c85b48..bdcb2f983 100644
--- a/bayesflow/networks/transformers/pma.py
+++ b/bayesflow/networks/transformers/pma.py
@@ -10,7 +10,7 @@
from .mab import MultiHeadAttentionBlock
-@serializable
+@serializable("bayesflow.networks")
class PoolingByMultiHeadAttention(keras.Layer):
"""Implements the pooling with multi-head attention (PMA) block from [1] which represents
a permutation-invariant encoder for set-based inputs.
diff --git a/bayesflow/networks/transformers/sab.py b/bayesflow/networks/transformers/sab.py
index a447d92a2..276383dfd 100644
--- a/bayesflow/networks/transformers/sab.py
+++ b/bayesflow/networks/transformers/sab.py
@@ -7,7 +7,7 @@
from .mab import MultiHeadAttentionBlock
-@serializable
+@serializable("bayesflow.networks")
class SetAttentionBlock(MultiHeadAttentionBlock):
"""Implements the SAB block from [1] which represents learnable self-attention.
diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py
index 6c0ab0efc..256d9e54d 100644
--- a/bayesflow/networks/transformers/set_transformer.py
+++ b/bayesflow/networks/transformers/set_transformer.py
@@ -11,7 +11,7 @@
from .pma import PoolingByMultiHeadAttention
-@serializable
+@serializable("bayesflow.networks")
class SetTransformer(SummaryNetwork):
"""Implements the set transformer architecture from [1] which ultimately represents
a learnable permutation-invariant function. Designed to naturally model interactions in
diff --git a/bayesflow/networks/transformers/time_series_transformer.py b/bayesflow/networks/transformers/time_series_transformer.py
index 16feca444..007ae8b74 100644
--- a/bayesflow/networks/transformers/time_series_transformer.py
+++ b/bayesflow/networks/transformers/time_series_transformer.py
@@ -10,7 +10,7 @@
from .mab import MultiHeadAttentionBlock
-@serializable
+@serializable("bayesflow.networks")
class TimeSeriesTransformer(SummaryNetwork):
def __init__(
self,
diff --git a/bayesflow/scores/mean_score.py b/bayesflow/scores/mean_score.py
index 553a7c3af..0c7f200b2 100644
--- a/bayesflow/scores/mean_score.py
+++ b/bayesflow/scores/mean_score.py
@@ -1,9 +1,8 @@
-from keras.saving import register_keras_serializable as serializable
-
+from bayesflow.utils.serialization import serializable
from .normed_difference_score import NormedDifferenceScore
-@serializable(package="bayesflow.scores")
+@serializable("bayesflow.scores")
class MeanScore(NormedDifferenceScore):
r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |^2`
diff --git a/bayesflow/scores/median_score.py b/bayesflow/scores/median_score.py
index 10c8809c3..385c47436 100644
--- a/bayesflow/scores/median_score.py
+++ b/bayesflow/scores/median_score.py
@@ -1,9 +1,8 @@
-from keras.saving import register_keras_serializable as serializable
-
+from bayesflow.utils.serialization import serializable
from .normed_difference_score import NormedDifferenceScore
-@serializable(package="bayesflow.scores")
+@serializable("bayesflow.scores")
class MedianScore(NormedDifferenceScore):
r""":math:`S(\hat \theta, \theta) = | \hat \theta - \theta |`
diff --git a/bayesflow/scores/multivariate_normal_score.py b/bayesflow/scores/multivariate_normal_score.py
index 84cfd4910..7c745919c 100644
--- a/bayesflow/scores/multivariate_normal_score.py
+++ b/bayesflow/scores/multivariate_normal_score.py
@@ -1,15 +1,15 @@
import math
import keras
-from keras.saving import register_keras_serializable as serializable
from bayesflow.types import Shape, Tensor
from bayesflow.links import PositiveDefinite
+from bayesflow.utils.serialization import serializable
from .parametric_distribution_score import ParametricDistributionScore
-@serializable(package="bayesflow.scores")
+@serializable("bayesflow.scores")
class MultivariateNormalScore(ParametricDistributionScore):
r""":math:`S(\hat p_{\mu, \Sigma}, \theta; k) = -\log( \mathcal N (\theta; \mu, \Sigma))`
diff --git a/bayesflow/scores/normed_difference_score.py b/bayesflow/scores/normed_difference_score.py
index eb2795927..d33bc128f 100644
--- a/bayesflow/scores/normed_difference_score.py
+++ b/bayesflow/scores/normed_difference_score.py
@@ -1,13 +1,13 @@
import keras
-from keras.saving import register_keras_serializable as serializable
from bayesflow.types import Shape, Tensor
from bayesflow.utils import weighted_mean
+from bayesflow.utils.serialization import serializable
from .scoring_rule import ScoringRule
-@serializable(package="bayesflow.scores")
+@serializable("bayesflow.scores")
class NormedDifferenceScore(ScoringRule):
r""":math:`S(\hat \theta, \theta; k) = | \hat \theta - \theta |^k`
diff --git a/bayesflow/scores/parametric_distribution_score.py b/bayesflow/scores/parametric_distribution_score.py
index 3ead3271f..91df32d48 100644
--- a/bayesflow/scores/parametric_distribution_score.py
+++ b/bayesflow/scores/parametric_distribution_score.py
@@ -1,12 +1,11 @@
-from keras.saving import register_keras_serializable as serializable
-
from bayesflow.types import Tensor
from bayesflow.utils import weighted_mean
+from bayesflow.utils.serialization import serializable
from .scoring_rule import ScoringRule
-@serializable(package="bayesflow.scores")
+@serializable("bayesflow.scores")
class ParametricDistributionScore(ScoringRule):
r""":math:`S(\hat p_\phi, \theta; k) = -\log(\hat p_\phi(\theta))`
diff --git a/bayesflow/scores/quantile_score.py b/bayesflow/scores/quantile_score.py
index b05a35fc5..7ba021340 100644
--- a/bayesflow/scores/quantile_score.py
+++ b/bayesflow/scores/quantile_score.py
@@ -1,16 +1,16 @@
from typing import Sequence
import keras
-from keras.saving import register_keras_serializable as serializable
from bayesflow.types import Shape, Tensor
from bayesflow.utils import logging, weighted_mean
+from bayesflow.utils.serialization import serializable
from bayesflow.links import OrderedQuantiles
from .scoring_rule import ScoringRule
-@serializable(package="bayesflow.scores")
+@serializable("bayesflow.scores")
class QuantileScore(ScoringRule):
r""":math:`S(\hat \theta_i, \theta; \tau_i)
= (\hat \theta_i - \theta)(\mathbf{1}_{\hat \theta - \theta > 0} - \tau_i)`
diff --git a/bayesflow/scores/scoring_rule.py b/bayesflow/scores/scoring_rule.py
index a1a3f5717..6dee0afec 100644
--- a/bayesflow/scores/scoring_rule.py
+++ b/bayesflow/scores/scoring_rule.py
@@ -1,13 +1,13 @@
import math
import keras
-from keras.saving import register_keras_serializable as serializable
from bayesflow.types import Shape, Tensor
-from bayesflow.utils import find_network, serialize_value_or_type, deserialize_value_or_type
+from bayesflow.utils import find_network
+from bayesflow.utils.serialization import deserialize, serializable, serialize
-@serializable(package="bayesflow.scores")
+@serializable("bayesflow.scores")
class ScoringRule:
"""Base class for scoring rules.
@@ -51,23 +51,16 @@ def __init__(
self.config = {"subnets_kwargs": self.subnets_kwargs}
def get_config(self):
- self.config["subnets"] = {
- key: serialize_value_or_type({}, "subnet", subnet) for key, subnet in self.subnets.items()
- }
- self.config["links"] = {key: serialize_value_or_type({}, "link", link) for key, link in self.links.items()}
+ self.config["subnets"] = {key: serialize(subnet) for key, subnet in self.subnets.items()}
+ self.config["links"] = {key: serialize(link) for key, link in self.links.items()}
return self.config
@classmethod
def from_config(cls, config):
config = config.copy()
- config["subnets"] = {
- key: deserialize_value_or_type(subnet_dict, "subnet")["subnet"]
- for key, subnet_dict in config["subnets"].items()
- }
- config["links"] = {
- key: deserialize_value_or_type(link_dict, "link")["link"] for key, link_dict in config["links"].items()
- }
+ config["subnets"] = {key: deserialize(subnet_dict) for key, subnet_dict in config["subnets"].items()}
+ config["links"] = {key: deserialize(link_dict) for key, link_dict in config["links"].items()}
return cls(**config)
diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py
index 737c533ce..47ab771ff 100644
--- a/bayesflow/utils/__init__.py
+++ b/bayesflow/utils/__init__.py
@@ -6,6 +6,7 @@
keras_utils,
logging,
numpy_utils,
+ serialization,
)
from .callbacks import detailed_loss_callback
@@ -104,4 +105,4 @@
from ._docs import _add_imports_to_all
-_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils"])
+_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization"])
diff --git a/bayesflow/utils/decorators.py b/bayesflow/utils/decorators.py
index 7fd32edc9..1283fe66a 100644
--- a/bayesflow/utils/decorators.py
+++ b/bayesflow/utils/decorators.py
@@ -17,6 +17,7 @@ def allow_args(fn: Decorator) -> Decorator:
def wrapper(f: Fn) -> Fn: ...
@overload
def wrapper(*fargs: any, **fkwargs: any) -> Fn: ...
+ @wraps(fn)
def wrapper(*fargs: any, **fkwargs: any) -> Fn:
if len(fargs) == 1 and not fkwargs and callable(fargs[0]):
# called without arguments
diff --git a/bayesflow/utils/serialization.py b/bayesflow/utils/serialization.py
index 500264f05..5be0e0e1d 100644
--- a/bayesflow/utils/serialization.py
+++ b/bayesflow/utils/serialization.py
@@ -5,6 +5,7 @@
import keras
import numpy as np
import sys
+from warnings import warn
# this import needs to be exactly like this to work with monkey patching
from keras.saving import deserialize_keras_object
@@ -19,111 +20,125 @@
def serialize_value_or_type(config, name, obj):
- """Serialize an object that can be either a value or a type
- and add it to a copy of the supplied dictionary.
+ """This function is deprecated."""
+ warn(
+ "This method is deprecated. It was replaced by bayesflow.utils.serialization.serialize.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
- Parameters
- ----------
- config : dict
- Dictionary to add the serialized object to. This function does not
- modify the dictionary in place, but returns a modified copy.
- name : str
- Name of the obj that should be stored. Required for later deserialization.
- obj : object or type
- The object to serialize. If `obj` is of type `type`, we use
- `keras.saving.get_registered_name` to obtain the registered type name.
- If it is not a type, we try to serialize it as a Keras object.
- Returns
- -------
- updated_config : dict
- Updated dictionary with a new key `"_bayesflow_
_type"` or
- `"_bayesflow__val"`. The prefix is used to avoid name collisions,
- the suffix indicates how the stored value has to be deserialized.
-
- Notes
- -----
- We allow strings or `type` parameters at several places to instantiate objects
- of a given type (e.g., `subnet` in `CouplingFlow`). As `type` objects cannot
- be serialized, we have to distinguish the two cases for serialization and
- deserialization. This function is a helper function to standardize and
- simplify this.
- """
- updated_config = config.copy()
- if isinstance(obj, type):
- updated_config[f"{PREFIX}{name}_type"] = keras.saving.get_registered_name(obj)
- else:
- updated_config[f"{PREFIX}{name}_val"] = keras.saving.serialize_keras_object(obj)
- return updated_config
+def deserialize_value_or_type(config, name):
+ """This function is deprecated."""
+ warn(
+ "This method is deprecated. It was replaced by bayesflow.utils.serialization.deserialize.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
-def deserialize_value_or_type(config, name):
- """Deserialize an object that can be either a value or a type and add
- it to the supplied dictionary.
+def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs):
+ """Deserialize an object serialized with :py:func:`serialize`.
+
+ Wrapper function around `keras.saving.deserialize_keras_object` to enable deserialization of
+ classes.
Parameters
----------
- config : dict
- Dictionary containing the object to deserialize. If a type was
- serialized, it should contain the key `"_bayesflow__type"`.
- If an object was serialized, it should contain the key
- `"_bayesflow__val"`. In a copy of this dictionary,
- the item will be replaced with the key `name`.
- name : str
- Name of the object to deserialize.
+ config : dict
+ Python dict describing the object.
+ custom_objects : dict, optional
+ Python dict containing a mapping between custom object names and the corresponding
+ classes or functions. Forwarded to `keras.saving.deserialize_keras_object`.
+ safe_mode : bool, optional
+ Boolean, whether to disallow unsafe lambda deserialization. When safe_mode=False,
+ loading an object has the potential to trigger arbitrary code execution. This argument
+ is only applicable to the Keras v3 model format. Defaults to True.
+ Forwarded to `keras.saving.deserialize_keras_object`.
Returns
-------
- updated_config : dict
- Updated dictionary with a new key `name`, with a value that is either
- a type or an object.
+ obj :
+ The object described by the config dictionary.
+
+ Raises
+ ------
+ ValueError
+ If a type in the config can not be deserialized.
See Also
--------
- serialize_value_or_type
+ serialize
"""
- updated_config = config.copy()
- if f"{PREFIX}{name}_type" in config:
- updated_config[name] = keras.saving.get_registered_object(config[f"{PREFIX}{name}_type"])
- del updated_config[f"{PREFIX}{name}_type"]
- elif f"{PREFIX}{name}_val" in config:
- updated_config[name] = keras.saving.deserialize_keras_object(config[f"{PREFIX}{name}_val"])
- del updated_config[f"{PREFIX}{name}_val"]
- return updated_config
-
-
-def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
with monkey_patch(deserialize_keras_object, deserialize) as original_deserialize:
- if isinstance(obj, str) and obj.startswith(_type_prefix):
+ if isinstance(config, str) and config.startswith(_type_prefix):
# we marked this as a type during serialization
- obj = obj[len(_type_prefix) :]
+ config = config[len(_type_prefix) :]
tp = keras.saving.get_registered_object(
# TODO: can we pass module objects without overwriting numpy's dict with builtins?
- obj,
+ config,
custom_objects=custom_objects,
module_objects=np.__dict__ | builtins.__dict__,
)
if tp is None:
raise ValueError(
- f"Could not deserialize type {obj!r}. Make sure it is registered with "
+ f"Could not deserialize type {config!r}. Make sure it is registered with "
f"`keras.saving.register_keras_serializable` or pass it in `custom_objects`."
)
return tp
- if inspect.isclass(obj):
+ if inspect.isclass(config):
# add this base case since keras does not cover it
- return obj
+ return config
- obj = original_deserialize(obj, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs)
+ obj = original_deserialize(config, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs)
return obj
@allow_args
-def serializable(cls, package=None, name=None):
- if package is None:
- frame = sys._getframe(1)
+def serializable(cls, package: str, name: str | None = None, disable_module_check: bool = False):
+ """Register class as Keras serializable.
+
+ Wrapper function around `keras.saving.register_keras_serializable` to automatically check consistency
+ of the supplied `package` argument with the module a class resides in. The `package` name should generally
+ be the module the class resides in, truncated at depth two. Valid examples would be "bayesflow.networks"
+ or "bayesflow.adapters". The check can be disabled if necessary by setting `disable_module_check` to True.
+ This should only be done in exceptional cases, and accompanied by a comment why it is necessary for a given
+ class.
+
+ Parameters
+ ----------
+ cls : type
+ The class to register.
+ package : str
+ `package` argument forwarded to `keras.saving.register_keras_serializable`.
+ Should generally correspond to the module of the class, truncated at depth two (e.g., "bayesflow.networks").
+ name : str, optional
+ `name` argument forwarded to `keras.saving.register_keras_serializable`.
+ If None is provided, the classe's __name__ attribute is used.
+ disable_module_check : bool, optional
+ Disable check that the provided `package` is consistent with the location of the class within the library.
+
+ Raises
+ ------
+ ValueError
+ If the supplied `package` does not correspond to the module of the class, truncated at depth two, and
+ `disable_module_check` is False. No error is thrown when a class is not part of the bayesflow module.
+ """
+ if not disable_module_check:
+ frame = sys._getframe(2)
g = frame.f_globals
- package = g.get("__name__", "bayesflow")
+ module_name = g.get("__name__", "")
+ # only apply this check if the class is inside the bayesflow module
+ is_bayesflow = module_name.split(".")[0] == "bayesflow"
+ auto_package = ".".join(module_name.split(".")[:2])
+ if is_bayesflow and package != auto_package:
+ raise ValueError(
+ "'package' should be the first two levels of the module the class resides in (e.g., bayesflow.networks)"
+ f'. In this case it should be \'package="{auto_package}"\' (was "{package}"). If this is not possible'
+ " (e.g., because a class was moved to a different module, and serializability should be preserved),"
+ " please set 'disable_module_check=True' and add a comment why it is necessary for this class."
+ )
if name is None:
name = copy(cls.__name__)
@@ -133,6 +148,26 @@ def serializable(cls, package=None, name=None):
def serialize(obj):
+ """Serialize an object using Keras.
+
+ Wrapper function around `keras.saving.serialize_keras_object`, which adds the
+ ability to serialize classes.
+
+ Parameters
+ ----------
+ object : Keras serializable object, or class
+ The object to serialize
+
+ Returns
+ -------
+ config : dict
+ A python dict that represents the object. The python dict can be deserialized via
+ :py:func:`deserialize`.
+
+ See Also
+ --------
+ deserialize
+ """
if isinstance(obj, (tuple, list, dict)):
return keras.tree.map_structure(serialize, obj)
elif inspect.isclass(obj):
diff --git a/bayesflow/wrappers/mamba/mamba.py b/bayesflow/wrappers/mamba/mamba.py
index d06508790..b328ede98 100644
--- a/bayesflow/wrappers/mamba/mamba.py
+++ b/bayesflow/wrappers/mamba/mamba.py
@@ -9,7 +9,7 @@
from .mamba_block import MambaBlock
-@serializable
+@serializable("bayesflow.wrappers")
class Mamba(SummaryNetwork):
"""
Wraps a sequence of Mamba modules using the simple Mamba module from:
diff --git a/bayesflow/wrappers/mamba/mamba_block.py b/bayesflow/wrappers/mamba/mamba_block.py
index b8ba36d2e..bd15ecc29 100644
--- a/bayesflow/wrappers/mamba/mamba_block.py
+++ b/bayesflow/wrappers/mamba/mamba_block.py
@@ -6,7 +6,7 @@
from bayesflow.utils.serialization import serializable
-@serializable
+@serializable("bayesflow.wrappers")
class MambaBlock(keras.Layer):
"""
Wraps the original Mamba module from, with added functionality for bidirectional processing:
diff --git a/docsrc/source/development/index.md b/docsrc/source/development/index.md
index c62971532..adbadf21f 100644
--- a/docsrc/source/development/index.md
+++ b/docsrc/source/development/index.md
@@ -1,87 +1,23 @@
-# Patterns & Caveats
+# Developer Documentation
-**Note**: This document is part of BayesFlow's developer documentation, and
+**Attention:** You are looking BayesFlow's developer documentation, which is
aimed at people who want to extend or improve BayesFlow. For user documentation,
-please refer to the examples and the public API documentation.
+please refer to the {doc}`../examples` and the {doc}`../api/bayesflow`.
-## Introduction
-
-From version 2 on, BayesFlow is built on [Keras](https://keras.io/) v3, which
-allows writing machine learning pipelines that run in JAX, TensorFlow and PyTorch.
-By using functionality provided by Keras, and extending it with backend-specific
-code where necessary, we aim to build BayesFlow in a backend-agnostic fashion as
-well.
-
-As Keras is built upon three different backend, each with different functionality
-and design decisions, it has its own quirks and compromises. This documents
-outlines some of them, along with the design decisions and programming patterns
-we use to counter them.
-
-This document is work in progress, so if you read through the code base and
+This section is work in progress, so if you read through the code base and
encounter something that looks odd, but shows up in multiple places, please open
an issue so that we can add it here. Also, if you introduce a new pattern that
others will have to use in the future as well, please document it here, along
with some background information on why it is necessary and how to use it in
practice.
-## Privileged `training` argument in the `call()` method cannot be passed via `kwargs`
-
-For layers that have different behavior at training and inference time (e.g.,
-dropout or batch normalization layers), a boolean `training` argument can be
-exposed, see [this section of the Keras documentation](https://keras.io/guides/making_new_layers_and_models_via_subclassing/#privileged-training-argument-in-the-call-method).
-If we want to pass this manually, we have to do so explicitly and not as part
-of a set of keyword arguments via `**kwargs`.
-
-@Lars: Maybe you can add more details on what is going on behind the scenes.
-
-## Serialization
-
-Serialization deals with the problem of storing objects to disk, and loading
-them at a later point in time. This is straight-forward for data structures like
-numpy arrays, but for classes with custom behavior, like approximators or neural
-network layers, it is somewhat more complex.
-
-Please refer to the Keras guide [Save, serialize, and export models](https://keras.io/guides/serialization_and_saving/)
-for an introduction, and [Customizing Saving and Serialization](https://keras.io/guides/customizing_saving_and_serialization/)
-for advanced concepts.
-
-The basic idea is: by storing the arguments of the constructor of a class
-(i.e., the arguments of the `__init__` function), we can later construct an
-object identical to the one we have stored, except for the weights.
-As the structure is identical, we can then map the stored weights to the newly
-constructed object. The caveat is that all arguments have to be either basic
-Python objects (like int, float, string, bool, ...) or themselves serializable.
-If they are not, we have to manually specify how to serialize them, and how to
-load them later on.
-
-### Registering classes as serializable
-
-TODO
-
-### Serialization of custom types
-
-In BayesFlow, we often encounter situations where we do not want to pass a
-specific object (e.g., an MPL of a certain size), but we want to pass its type
-(MLP) and the arguments to construct it. With the type and the arguments, we can
-then construct multiple instances of the network in different places, for example
-as the network inside a coupling block.
-
-Unfortunately, `type` is not Keras serializable, so we have to serialize those
-arguments manually. To complicate matters further, we also allow passing a string
-instead of a type, which is then used to select the correct type.
-
-To make it more concrete, we look at the `CouplingFlow` class, which takes the
-argument `subnet` that provide the type of the subnet. It is either a
-string (e.g., `"mlp"`) or a class (e.g., `bayesflow.networks.MLP`). In the first
-case, we can just store the value and load it, in the latter case, we first have
-to convert the type to a string that we can later convert back into a type.
-
-We provide two helper functions that can deal with both cases:
-`bayesflow.utils.serialize_value_or_type(config, name, obj)` and
-`bayesflow.utils.deserialize_value_or_type(config, name)`.
-In `get_config`, we use the first to store the object, whereas we use the
-latter in `from_config` to load it again.
+```{toctree}
+:maxdepth: 1
+:titlesonly:
+:numbered:
-As we need all arguments to `__init__` in `get_config`, it can make sense to
-build a `config` dictionary in `__init__` already, which can then be stored when
-`get_config` is called. Take a look at `CouplingFlow` for an example of that.
+introduction
+pitfalls
+stages
+serialization
+```
diff --git a/docsrc/source/development/introduction.md b/docsrc/source/development/introduction.md
new file mode 100644
index 000000000..a60830c2a
--- /dev/null
+++ b/docsrc/source/development/introduction.md
@@ -0,0 +1,12 @@
+# Introduction
+
+From version 2 on, BayesFlow is built on [Keras3](https://keras.io/), which
+allows writing machine learning pipelines that run in JAX, TensorFlow and PyTorch.
+By using functionality provided by Keras, and extending it with backend-specific
+code where necessary, we aim to build BayesFlow in a backend-agnostic fashion as
+well.
+
+As Keras is built upon three different backends, each with different functionality
+and design decisions, it comes with its own quirks and compromises. The following documents
+outline some of them, along with the design decisions and programming patterns
+we use to counter them.
diff --git a/docsrc/source/development/pitfalls.md b/docsrc/source/development/pitfalls.md
new file mode 100644
index 000000000..69d183ec1
--- /dev/null
+++ b/docsrc/source/development/pitfalls.md
@@ -0,0 +1,13 @@
+# Potential Pitfalls
+
+This document covers things we have learned during development that might cause problems or hard to find bugs.
+
+## Privileged `training` argument in the `call()` method cannot be passed via `kwargs`
+
+For layers that have different behavior at training and inference time (e.g.,
+dropout or batch normalization layers), a boolean `training` argument can be
+exposed, see [this section of the Keras documentation](https://keras.io/guides/making_new_layers_and_models_via_subclassing/#privileged-training-argument-in-the-call-method).
+If we want to pass this manually, we have to do so explicitly and not as part
+of a set of keyword arguments via `**kwargs`.
+
+@Lars: Maybe you can add more details on what is going on behind the scenes.
diff --git a/docsrc/source/development/serialization.md b/docsrc/source/development/serialization.md
new file mode 100644
index 000000000..ec8988454
--- /dev/null
+++ b/docsrc/source/development/serialization.md
@@ -0,0 +1,35 @@
+# Serialization: Enable Model Saving & Loading
+
+Serialization deals with the problem of storing objects to disk, and loading them at a later point in time.
+This is straight-forward for data structures like numpy arrays, but for classes with custom behavior it is somewhat more complex.
+
+Please refer to the Keras guide [Save, serialize, and export models](https://keras.io/guides/serialization_and_saving/) for an introduction, and [Customizing Saving and Serialization](https://keras.io/guides/customizing_saving_and_serialization/) for advanced concepts.
+
+The basic idea is: by storing the arguments of the constructor of a class (i.e., the arguments of the `__init__` function), we can later construct an object similar to the one we have stored, except for the weights and other stateful content.
+As the structure is identical, we can then map the stored weights to the newly constructed object.
+The caveat is that all arguments have to be either basic Python objects (like int, float, string, bool, ...) or themselves serializable.
+If they are not, we have to manually specify how to serialize them, and how to load them later on.
+One important example is that types are not serializable.
+As we want/need to pass them in some places, we have to resort to some custom behavior, that is described below.
+
+## Serialization Utilities
+
+BayesFlows serialization utilities can be found in the {py:mod}`~bayesflow.utils.serialization` module.
+We mainly provide three convenience functions:
+
+- The {py:func}`~bayesflow.utils.serialization.serializable` decorator wraps the `keras.saving.register_keras_serializable` function to ensure consistent naming of the `package` argument within the library.
+- The {py:func}`~bayesflow.utils.serialization.serialize` function, which adds support for serializing classes.
+- Its counterpart {py:func}`~bayesflow.utils.serialization.deserialize`, adds support to deserialize classes.
+
+## Usage
+
+To use the adapted serialization functions, you have to use them in the `get_config` and `from_config` method. Please refer to existing classes in the library for usage examples.
+
+### The `serializable` Decorator
+
+To make serialization as little confusing as possible, as well as providing stability even when moving classes around, we provide the `package` argument explicitly for each class.
+The naming should respect the following naming scheme: Take the module the class resides in (for example, `bayesflow.adapters.transforms.standardize`), and truncate the path to depth two (`bayesflow.adapters`).
+In cases where this convention cannot be followed, set `disable_module_check` to `True`, and describe why a different name was necessary.
+Changing `package` breaks backwards-compatibility for serialization, so it should be avoided whenever possible.
+If you move a class to a different module (without changing the class itself), keep the `package` and set `disable_module_check` to `True`.
+This may later be adapted in a release that breaks backward compatiblity anyways.
diff --git a/docsrc/source/development/stages.md b/docsrc/source/development/stages.md
new file mode 100644
index 000000000..e9aa4ad8f
--- /dev/null
+++ b/docsrc/source/development/stages.md
@@ -0,0 +1,8 @@
+# Stages
+
+To keep track of the phase each functionality is called in, we provide a `stage` parameter.
+There are three stages:
+
+- `training`: The stage to train approximator (and related stateful objects, like the adapter)
+- `validation`: Identical setting to `training`, but calls in this stage should _not_ change the approximator
+- `inference`: Calls in this change should not change the approximator. In addition, the input structure might be different compared to the training phase. For example for sampling, we only provide `summary_conditions` and `inference_conditions`, but not the `inference_variables`, which we want to infer.
diff --git a/docsrc/source/index.md b/docsrc/source/index.md
index f89c5ff3f..7edaa4b9f 100644
--- a/docsrc/source/index.md
+++ b/docsrc/source/index.md
@@ -10,6 +10,9 @@ It provides users and researchers with:
BayesFlow (version 2+) is designed to be a flexible and efficient tool that enables rapid statistical inference
fueled by continuous progress in generative AI and Bayesian inference.
+To access the documentation for [BayesFlow version 1.x](https://github.com/bayesflow-org/bayesflow/tree/stable-legacy), select `stable-legacy` in the version picker above.
+For advice on the migration from version 1.x to version 2+, please refer to the [README](https://github.com/bayesflow-org/bayesflow/blob/main/README.md).
+
## Conceptual Overview
@@ -45,15 +48,7 @@ history = workflow.fit_online(epochs=50, batch_size=32, num_batches_per_epoch=50
diagnostics = workflow.plot_default_diagnostics(test_data=300)
```
-For an in-depth exposition, check out our walkthrough notebooks below.
-
-1. [Linear regression starter example](_examples/Linear_Regression_Starter.ipynb)
-2. [From ABC to BayesFlow](_examples/From_ABC_to_BayesFlow.ipynb)
-3. [Two moons starter example](_examples/Two_Moons_Starter.ipynb)
-4. [Rapid iteration with point estimators](_examples/Lotka_Volterra_point_estimation_and_expert_stats.ipynb)
-5. [SIR model with custom summary network](_examples/SIR_Posterior_Estimation.ipynb)
-6. [Bayesian experimental design](_examples/Bayesian_Experimental_Design.ipynb)
-7. [Simple model comparison example](_examples/One_Sample_TTest.ipynb)
+For an in-depth exposition, check out our walkthrough notebooks in the {doc}`Examples <../examples>` section.
More tutorials are always welcome! Please consider making a pull request if you have a cool application that you want to contribute.
diff --git a/examples/From_BayesFlow_1.1_to_2.0.ipynb b/examples/From_BayesFlow_1.1_to_2.0.ipynb
index fa0e13876..b8090fa79 100644
--- a/examples/From_BayesFlow_1.1_to_2.0.ipynb
+++ b/examples/From_BayesFlow_1.1_to_2.0.ipynb
@@ -15,9 +15,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Long-time users of BayesFlow will notice that with the update to version 2.0 many things have changed. This short guide aims to clarify some of those changes. Users familiar with the previous Quickstart guide will notice that this notebook follows a similar structure, but assumes that users are already familiar with BayesFlow. We omit many of the the mathematical explanations in favor of demonstrating the differences in workflow. For a more detailed explanation of the BayesFlow framework, users should read, for example, the Starter Notebook on Bayesian Linear Regression.\n",
+ "Long-time users of BayesFlow will notice that with the update to version 2.0 many things have changed. This short guide aims to clarify some of those changes. Users familiar with the previous Quickstart guide will notice that this notebook follows a similar structure, but assumes that users are already familiar with BayesFlow. We omit many of the mathematical explanations in favor of demonstrating the differences in workflow. For a more detailed explanation of the BayesFlow framework, users should read, for example, the Starter Notebook on Bayesian Linear Regression.\n",
"\n",
- "Additionally to avoid confusion, similarly named objects from _BayesFlow v1.1_ will have `1.1` after their name, whereas those from _BayesFlow v2.0_ will not. Finally, a short table with a summary of the function call changes is provided at the end of the guide. "
+ "Additionally, to avoid confusion, similarly named objects from _BayesFlow v1.1_ will have `1.1` after their name, whereas those from _BayesFlow v2.0_ will not. Finally, a short table with a summary of the function call changes is provided at the end of the guide. "
]
},
{
@@ -26,7 +26,7 @@
"source": [
"## Keras Framework\n",
"\n",
- "BayesFlow 2.0 looks quite different from BayesFlow 1.1 because the library was refactored to replace the old backend, TensorFlow, with the new [Keras](https://keras.io) API. Users can now choose their preferred backend among the machine learning frameworks `TensorFlow`, `JAX` and `PyTorch`."
+ "BayesFlow 2.0 looks quite different from BayesFlow 1.1 because the library was refactored to replace the old backend, TensorFlow, with the new [Keras3](https://keras.io) API. Users can now choose their preferred backend among the machine learning frameworks [TensorFlow](https://github.com/tensorflow/tensorflow), [JAX](https://github.com/google/jax) and [PyTorch](https://github.com/pytorch/pytorch)."
]
},
{
@@ -52,7 +52,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "In general, BayesFlow 2.0 relies much more on dictionaries since parameters are now named by convention. Many objects now expect a dictionary, and parameters and data are returned as dictionaries as well. "
+ "In general, BayesFlow 2.0 relies much more on dictionaries since parameters and data are now named by convention. Many objects now expect a dictionary, and parameters and data are returned as dictionaries as well. "
]
},
{
@@ -67,7 +67,7 @@
"\n",
"Previously users would define a prior function and pass it to a `Prior1.1` object to sample prior values. The likelihood would also be specified via a function and passed to a `Simulator1.1` wrapper to produce observations for given parameter values. These were then combined in the `GenerativeModel1.1`. \n",
"\n",
- "In 2.0 we no longer make use of the `Prior1.1`, `Simulator1.1` or `GenerativeModel1.1` objects. Instead, the `Simulator` class comprises the whole functionality, taking the role of the `GenerativeModel1.1`. It directly produces joint samples from prior and likelihood, without creating separate `Prior1.1` and `Simulator1.1` objects first. The `bf.simulator.make_simulator` offers a convenient wrapper to create the appropriate simulator for different settings."
+ "In 2.0 we no longer make use of the `Prior1.1`, `Simulator1.1` or `GenerativeModel1.1` objects. Instead, the `Simulator` class comprises the whole functionality, taking the role of the `GenerativeModel1.1`. It directly produces joint samples from prior and likelihood, without creating separate `Prior1.1` and `Simulator1.1` objects first. The `bf.make_simulator` offers a convenient wrapper to create the appropriate simulator for different settings."
]
},
{
@@ -217,7 +217,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "For the inference network there are now several implemented architectures for users to choose from. They are `FlowMatching`, `ConsistencyModel`, and `CouplingFlow`. For this demonstration we use `FlowMatching`, for explanations on the different models please refer to the other examples and the API documentation. "
+ "The previous version only featured the `InvertibleNetwork1.1` class as inference network. As the generative modeling landscape has evolved since then, there are now several implemented architectures for users to choose from. `CouplingFlow` features an architecture similar to `InvertibleNetwork1.1`. `FlowMatching` is a continuous flow architecture, which can express more complex distributions, at the cost of higher inference time. For this demonstration we use `FlowMatching`. For explanations on the different models please refer to the other examples and the API documentation. "
]
},
{
@@ -377,10 +377,11 @@
"| BayesFlow v1.1 | BayesFlow v2.0 usage |\n",
"| :--------| :---------| \n",
"| `Prior`, `Simulator` | Defunct and no longer standalone objects but incorporated into `bf.simulators.Simulator` | \n",
- "|`GenerativeModel` | Defunct with it's functionality having been taken over by `bf.simulators.make_simulator` | \n",
+ "| `GenerativeModel` | Defunct with it's functionality having been taken over by `bf.make_simulator` | \n",
"| `training.configurator` | Functionality taken over by `bf.adapters.Adapter` | \n",
"|`Trainer` | Functionality taken over by `fit` method of `bf.approximators.Approximator` | \n",
- "| `AmortizedPosterior`, `AmortizedLikelihood` | Functionality taken over by `ContinuousApproximator` | "
+ "| `AmortizedPosterior`, `AmortizedLikelihood` | Functionality taken over by `ContinuousApproximator` |\n",
+ "| `InvertibleNetwork` | Functionality taken over by `CouplingFlow`, but also other networks that are subclasses of `InferenceNetwork`, e.g. `FlowMatching` |"
]
}
],
diff --git a/examples/SIR_Posterior_Estimation.ipynb b/examples/SIR_Posterior_Estimation.ipynb
index 7963d00e5..2e3f94bd9 100644
--- a/examples/SIR_Posterior_Estimation.ipynb
+++ b/examples/SIR_Posterior_Estimation.ipynb
@@ -84,7 +84,7 @@
"id": "39846c15b88eaf8e",
"metadata": {},
"source": [
- "As described in our [very first notebook](linear_Regression_Starter.ipynb), a generative model consists of a prior (encoding suitable parameter ranges) and a simulator (generating data given simulations). Our underlying model distinguishes between susceptible, $S$, infected, $I$, and recovered, $R$, individuals with infection and recovery occurring at a constant transmission rate $\\lambda$ and constant recovery rate $\\mu$, respectively. The model dynamics are governed by the following system of ODEs:\n",
+ "As described in our [very first notebook](Linear_Regression_Starter.ipynb), a generative model consists of a prior (encoding suitable parameter ranges) and a simulator (generating data given simulations). Our underlying model distinguishes between susceptible, $S$, infected, $I$, and recovered, $R$, individuals with infection and recovery occurring at a constant transmission rate $\\lambda$ and constant recovery rate $\\mu$, respectively. The model dynamics are governed by the following system of ODEs:\n",
"\n",
"$$\n",
"\\begin{align}\n",
diff --git a/pyproject.toml b/pyproject.toml
index a42b1df97..07a5e924b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "bayesflow"
-version = "2.0.2"
+version = "2.0.3"
authors = [{ name = "The BayesFlow Team" }]
classifiers = [
"Development Status :: 5 - Production/Stable",
diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py
index 873279f09..d69cd4be4 100644
--- a/tests/test_adapters/conftest.py
+++ b/tests/test_adapters/conftest.py
@@ -49,6 +49,8 @@ def random_data():
"z1": np.random.standard_normal(size=(32, 2)),
"p1": np.random.lognormal(size=(32, 2)),
"p2": np.random.lognormal(size=(32, 2)),
+ "p3": np.random.lognormal(size=(32, 2)),
+ "n1": 1 - np.random.lognormal(size=(32, 2)),
"s1": np.random.standard_normal(size=(32, 3, 2)),
"s2": np.random.standard_normal(size=(32, 3, 2)),
"t1": np.zeros((3, 2)),
@@ -56,5 +58,43 @@ def random_data():
"d1": np.random.standard_normal(size=(32, 2)),
"d2": np.random.standard_normal(size=(32, 2)),
"o1": np.random.randint(0, 9, size=(32, 2)),
+ "u1": np.random.uniform(low=-1, high=2, size=(32, 1)),
"key_to_split": np.random.standard_normal(size=(32, 10)),
}
+
+
+@pytest.fixture()
+def adapter_log_det_jac():
+ from bayesflow.adapters import Adapter
+
+ adapter = (
+ Adapter()
+ .scale("x1", by=2)
+ .log("p1", p1=True)
+ .sqrt("p2")
+ .constrain("p3", lower=0)
+ .constrain("n1", upper=1)
+ .constrain("u1", lower=-1, upper=2)
+ .concatenate(["p1", "p2", "p3"], into="p")
+ .rename("u1", "u")
+ )
+
+ return adapter
+
+
+@pytest.fixture()
+def adapter_log_det_jac_inverse():
+ from bayesflow.adapters import Adapter
+
+ adapter = (
+ Adapter()
+ .standardize("x1", mean=1, std=2)
+ .log("p1")
+ .sqrt("p2")
+ .constrain("p3", lower=0, method="log")
+ .constrain("n1", upper=1, method="log")
+ .constrain("u1", lower=-1, upper=2)
+ .scale(["p1", "p2", "p3"], by=3.5)
+ )
+
+ return adapter
diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py
index d6215170e..1784befb7 100644
--- a/tests/test_adapters/test_adapters.py
+++ b/tests/test_adapters/test_adapters.py
@@ -13,7 +13,7 @@ def test_cycle_consistency(adapter, random_data):
deprocessed = adapter(processed, inverse=True)
for key, value in random_data.items():
- if key in ["d1", "d2"]:
+ if key in ["d1", "d2", "p3", "n1", "u1"]:
# dropped
continue
assert key in deprocessed
@@ -230,3 +230,62 @@ def test_to_dict_transform():
# category should have 5 one-hot categories, even though it was only passed 4
assert processed["category"].shape[-1] == 5
+
+
+def test_log_det_jac(adapter_log_det_jac, random_data):
+ d, log_det_jac = adapter_log_det_jac(random_data, log_det_jac=True)
+
+ assert np.allclose(log_det_jac["x1"], np.log(2))
+
+ p1 = -np.log1p(random_data["p1"])
+ p2 = -0.5 * np.log(random_data["p2"]) - np.log(2)
+ p3 = random_data["p3"] - np.log(np.exp(random_data["p3"]) - 1)
+ p = np.sum(p1, axis=-1) + np.sum(p2, axis=-1) + np.sum(p3, axis=-1)
+
+ assert np.allclose(log_det_jac["p"], p)
+
+ n1 = -(random_data["n1"] - 1)
+ n1 = n1 - np.log(np.exp(n1) - 1)
+ n1 = np.sum(n1, axis=-1)
+
+ assert np.allclose(log_det_jac["n1"], n1)
+
+ u1 = random_data["u1"]
+ u1 = (u1 + 1) / 3
+ u1 = -np.log(u1) - np.log1p(-u1) - np.log(3)
+
+ assert np.allclose(log_det_jac["u"], u1[:, 0])
+
+
+def test_log_det_jac_inverse(adapter_log_det_jac_inverse, random_data):
+ d, forward_log_det_jac = adapter_log_det_jac_inverse(random_data, log_det_jac=True)
+ d, inverse_log_det_jac = adapter_log_det_jac_inverse(d, inverse=True, log_det_jac=True)
+
+ for key in forward_log_det_jac.keys():
+ assert np.allclose(forward_log_det_jac[key], -inverse_log_det_jac[key])
+
+
+def test_log_det_jac_exceptions(random_data):
+ # Test cannot compute inverse log_det_jac
+ # e.g., when we apply a concat and then a transform that
+ # we cannot "unconcatenate" the log_det_jac
+ # (because the log_det_jac are summed, not concatenated)
+ adapter = bf.Adapter().concatenate(["p1", "p2", "p3"], into="p").sqrt("p")
+ transformed_data, log_det_jac = adapter(random_data, log_det_jac=True)
+
+ # test that inverse raises error
+ with pytest.raises(ValueError):
+ adapter(transformed_data, inverse=True, log_det_jac=True)
+
+ # test resolvable order: first transform, then concatenate
+ adapter = bf.Adapter().sqrt(["p1", "p2", "p3"]).concatenate(["p1", "p2", "p3"], into="p")
+
+ transformed_data, forward_log_det_jac = adapter(random_data, log_det_jac=True)
+ data, inverse_log_det_jac = adapter(transformed_data, inverse=True, log_det_jac=True)
+ inverse_log_det_jac = sum(inverse_log_det_jac.values())
+
+ # forward is the same regardless
+ assert np.allclose(forward_log_det_jac["p"], log_det_jac["p"])
+
+ # inverse works when concatenation is used after transforms
+ assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac)
diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py
index 125371a52..227e70ff1 100644
--- a/tests/test_approximators/conftest.py
+++ b/tests/test_approximators/conftest.py
@@ -163,3 +163,34 @@ def validation_dataset(batch_size, adapter, simulator):
num_batches = 2
data = simulator.sample((num_batches * batch_size,))
return OfflineDataset(data=data, adapter=adapter, batch_size=batch_size, workers=4, max_queue_size=num_batches)
+
+
+@pytest.fixture()
+def mean_std_summary_network():
+ from tests.utils import MeanStdSummaryNetwork
+
+ return MeanStdSummaryNetwork()
+
+
+@pytest.fixture(params=["continuous_approximator", "point_approximator", "model_comparison_approximator"])
+def approximator_with_summaries(request):
+ from bayesflow.adapters import Adapter
+
+ adapter = Adapter()
+ match request.param:
+ case "continuous_approximator":
+ from bayesflow.approximators import ContinuousApproximator
+
+ return ContinuousApproximator(adapter=adapter, inference_network=None, summary_network=None)
+ case "point_approximator":
+ from bayesflow.approximators import PointApproximator
+
+ return PointApproximator(adapter=adapter, inference_network=None, summary_network=None)
+ case "model_comparison_approximator":
+ from bayesflow.approximators import ModelComparisonApproximator
+
+ return ModelComparisonApproximator(
+ num_models=2, classifier_network=None, adapter=adapter, summary_network=None
+ )
+ case _:
+ raise ValueError("Invalid param for approximator class.")
diff --git a/tests/test_approximators/test_summaries.py b/tests/test_approximators/test_summaries.py
new file mode 100644
index 000000000..7962ddaab
--- /dev/null
+++ b/tests/test_approximators/test_summaries.py
@@ -0,0 +1,23 @@
+import pytest
+from tests.utils import assert_allclose
+import keras
+
+
+def test_valid_summaries(approximator_with_summaries, mean_std_summary_network, monkeypatch):
+ monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
+ summaries = approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))})
+ assert_allclose(summaries, keras.ops.stack([keras.ops.ones((2,)), keras.ops.zeros((2,))], axis=-1))
+
+
+def test_no_summary_network(approximator_with_summaries, monkeypatch):
+ monkeypatch.setattr(approximator_with_summaries, "summary_network", None)
+
+ with pytest.raises(ValueError):
+ approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))})
+
+
+def test_no_summary_variables(approximator_with_summaries, mean_std_summary_network, monkeypatch):
+ monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
+
+ with pytest.raises(ValueError):
+ approximator_with_summaries.summaries({})
diff --git a/tests/test_diagnostics/conftest.py b/tests/test_diagnostics/conftest.py
index 8e77d6729..dc859d2d4 100644
--- a/tests/test_diagnostics/conftest.py
+++ b/tests/test_diagnostics/conftest.py
@@ -78,3 +78,17 @@ def history():
}
return h
+
+
+@pytest.fixture()
+def adapter():
+ from bayesflow.adapters import Adapter
+
+ return Adapter.create_default("parameters").rename("observables", "summary_variables")
+
+
+@pytest.fixture()
+def summary_network():
+ from tests.utils import MeanStdSummaryNetwork
+
+ return MeanStdSummaryNetwork()
diff --git a/tests/test_diagnostics/test_diagnostics_metrics.py b/tests/test_diagnostics/test_diagnostics_metrics.py
index 4fb0945b3..3a2c711bc 100644
--- a/tests/test_diagnostics/test_diagnostics_metrics.py
+++ b/tests/test_diagnostics/test_diagnostics_metrics.py
@@ -1,6 +1,9 @@
-import bayesflow as bf
+import numpy as np
+import keras
import pytest
+import bayesflow as bf
+
def num_variables(x: dict):
return sum(arr.shape[-1] for arr in x.values())
@@ -79,3 +82,288 @@ def test_expected_calibration_error(pred_models, true_models, model_names):
with pytest.raises(Exception):
out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose)
+
+
+def test_bootstrap_comparison_shapes():
+ """Test the bootstrap_comparison output shapes."""
+ observed_samples = np.random.rand(10, 5)
+ reference_samples = np.random.rand(100, 5)
+ num_null_samples = 50
+
+ distance_observed, distance_null = bf.diagnostics.metrics.bootstrap_comparison(
+ observed_samples,
+ reference_samples,
+ lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)),
+ num_null_samples,
+ )
+
+ assert isinstance(distance_observed, float)
+ assert isinstance(distance_null, np.ndarray)
+ assert distance_null.shape == (num_null_samples,)
+
+
+def test_bootstrap_comparison_same_distribution():
+ """Test bootstrap_comparison on same distributions."""
+ observed_samples = np.random.normal(loc=0.5, scale=0.1, size=(10, 5))
+ reference_samples = observed_samples.copy()
+ num_null_samples = 5
+
+ distance_observed, distance_null = bf.diagnostics.metrics.bootstrap_comparison(
+ observed_samples,
+ reference_samples,
+ lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)),
+ num_null_samples,
+ )
+
+ assert distance_observed <= np.quantile(distance_null, 0.99)
+
+
+def test_bootstrap_comparison_different_distributions():
+ """Test bootstrap_comparison on different distributions."""
+ observed_samples = np.random.normal(loc=-5, scale=0.1, size=(10, 5))
+ reference_samples = np.random.normal(loc=5, scale=0.1, size=(100, 5))
+ num_null_samples = 50
+
+ distance_observed, distance_null = bf.diagnostics.metrics.bootstrap_comparison(
+ observed_samples,
+ reference_samples,
+ lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)),
+ num_null_samples,
+ )
+
+ assert distance_observed >= np.quantile(distance_null, 0.68)
+
+
+def test_bootstrap_comparison_mismatched_shapes():
+ """Test bootstrap_comparison raises ValueError for mismatched shapes."""
+ observed_samples = np.random.rand(10, 5)
+ reference_samples = np.random.rand(20, 4)
+ num_null_samples = 10
+
+ with pytest.raises(ValueError):
+ bf.diagnostics.metrics.bootstrap_comparison(
+ observed_samples,
+ reference_samples,
+ lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)),
+ num_null_samples,
+ )
+
+
+def test_bootstrap_comparison_num_observed_exceeds_num_reference():
+ """Test bootstrap_comparison raises ValueError when number of observed samples exceeds the number of reference
+ samples."""
+ observed_samples = np.random.rand(100, 5)
+ reference_samples = np.random.rand(20, 5)
+ num_null_samples = 50
+
+ with pytest.raises(ValueError):
+ bf.diagnostics.metrics.bootstrap_comparison(
+ observed_samples,
+ reference_samples,
+ lambda x, y: keras.ops.abs(keras.ops.mean(x) - keras.ops.mean(y)),
+ num_null_samples,
+ )
+
+
+def test_mmd_comparison_from_summaries_shapes():
+ """Test the mmd_comparison_from_summaries output shapes."""
+ observed_summaries = np.random.rand(10, 5)
+ reference_summaries = np.random.rand(100, 5)
+ num_null_samples = 50
+
+ mmd_observed, mmd_null = bf.diagnostics.metrics.bootstrap_comparison(
+ observed_summaries,
+ reference_summaries,
+ comparison_fn=bf.metrics.functional.maximum_mean_discrepancy,
+ num_null_samples=num_null_samples,
+ )
+
+ assert isinstance(mmd_observed, float)
+ assert isinstance(mmd_null, np.ndarray)
+ assert mmd_null.shape == (num_null_samples,)
+
+
+def test_mmd_comparison_from_summaries_positive():
+ """Test MMD output values of mmd_comparison_from_summaries are positive."""
+ observed_summaries = np.random.rand(10, 5)
+ reference_summaries = np.random.rand(100, 5)
+ num_null_samples = 50
+
+ mmd_observed, mmd_null = bf.diagnostics.metrics.bootstrap_comparison(
+ observed_summaries,
+ reference_summaries,
+ comparison_fn=bf.metrics.functional.maximum_mean_discrepancy,
+ num_null_samples=num_null_samples,
+ )
+
+ assert mmd_observed >= 0
+ assert np.all(mmd_null >= 0)
+
+
+def test_mmd_comparison_from_summaries_same_distribution():
+ """Test mmd_comparison_from_summaries on same distributions."""
+ observed_summaries = np.random.rand(10, 5)
+ reference_summaries = observed_summaries.copy()
+ num_null_samples = 5
+
+ mmd_observed, mmd_null = bf.diagnostics.metrics.bootstrap_comparison(
+ observed_summaries,
+ reference_summaries,
+ comparison_fn=bf.metrics.functional.maximum_mean_discrepancy,
+ num_null_samples=num_null_samples,
+ )
+
+ assert mmd_observed <= np.quantile(mmd_null, 0.99)
+
+
+def test_mmd_comparison_from_summaries_different_distributions():
+ """Test mmd_comparison_from_summaries on different distributions."""
+ observed_summaries = np.random.rand(10, 5)
+ reference_summaries = np.random.normal(loc=0.5, scale=0.1, size=(100, 5))
+ num_null_samples = 50
+
+ mmd_observed, mmd_null = bf.diagnostics.metrics.bootstrap_comparison(
+ observed_summaries,
+ reference_summaries,
+ comparison_fn=bf.metrics.functional.maximum_mean_discrepancy,
+ num_null_samples=num_null_samples,
+ )
+
+ assert mmd_observed >= np.quantile(mmd_null, 0.68)
+
+
+def test_mmd_comparison_shapes(summary_network, adapter):
+ """Test the mmd_comparison output shapes."""
+ observed_data = dict(observables=np.random.rand(10, 5))
+ reference_data = dict(observables=np.random.rand(100, 5))
+ num_null_samples = 50
+
+ mock_approximator = bf.approximators.ContinuousApproximator(
+ adapter=adapter,
+ inference_network=None,
+ summary_network=summary_network,
+ )
+
+ mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison(
+ observed_data=observed_data,
+ reference_data=reference_data,
+ approximator=mock_approximator,
+ num_null_samples=num_null_samples,
+ comparison_fn=bf.metrics.functional.maximum_mean_discrepancy,
+ )
+
+ assert isinstance(mmd_observed, float)
+ assert isinstance(mmd_null, np.ndarray)
+ assert mmd_null.shape == (num_null_samples,)
+
+
+def test_mmd_comparison_positive(summary_network, adapter):
+ """Test MMD output values of mmd_comparison are positive."""
+ observed_data = dict(observables=np.random.rand(10, 5))
+ reference_data = dict(observables=np.random.rand(100, 5))
+ num_null_samples = 50
+
+ mock_approximator = bf.approximators.ContinuousApproximator(
+ adapter=adapter,
+ inference_network=None,
+ summary_network=summary_network,
+ )
+
+ mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison(
+ observed_data=observed_data,
+ reference_data=reference_data,
+ approximator=mock_approximator,
+ num_null_samples=num_null_samples,
+ comparison_fn=bf.metrics.functional.maximum_mean_discrepancy,
+ )
+
+ assert mmd_observed >= 0
+ assert np.all(mmd_null >= 0)
+
+
+def test_mmd_comparison_same_distribution(summary_network, adapter):
+ """Test mmd_comparison on same distributions."""
+ observed_data = dict(observables=np.random.rand(10, 5))
+ reference_data = observed_data
+ num_null_samples = 5
+
+ mock_approximator = bf.approximators.ContinuousApproximator(
+ adapter=adapter,
+ inference_network=None,
+ summary_network=summary_network,
+ )
+
+ mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison(
+ observed_data=observed_data,
+ reference_data=reference_data,
+ approximator=mock_approximator,
+ num_null_samples=num_null_samples,
+ comparison_fn=bf.metrics.functional.maximum_mean_discrepancy,
+ )
+
+ assert mmd_observed <= np.quantile(mmd_null, 0.99)
+
+
+def test_mmd_comparison_different_distributions(summary_network, adapter):
+ """Test mmd_comparison on different distributions."""
+ observed_data = dict(observables=np.random.rand(10, 5))
+ reference_data = dict(observables=np.random.normal(loc=0.5, scale=0.1, size=(100, 5)))
+ num_null_samples = 50
+
+ mock_approximator = bf.approximators.ContinuousApproximator(
+ adapter=adapter,
+ inference_network=None,
+ summary_network=summary_network,
+ )
+
+ mmd_observed, mmd_null = bf.diagnostics.metrics.summary_space_comparison(
+ observed_data=observed_data,
+ reference_data=reference_data,
+ approximator=mock_approximator,
+ num_null_samples=num_null_samples,
+ comparison_fn=bf.metrics.functional.maximum_mean_discrepancy,
+ )
+
+ assert mmd_observed >= np.quantile(mmd_null, 0.68)
+
+
+def test_mmd_comparison_no_summary_network(adapter):
+ observed_data = dict(observables=np.random.rand(10, 5))
+ reference_data = dict(observables=np.random.rand(100, 5))
+ num_null_samples = 50
+
+ mock_approximator = bf.approximators.ContinuousApproximator(
+ adapter=adapter,
+ inference_network=None,
+ summary_network=None,
+ )
+
+ with pytest.raises(ValueError):
+ bf.diagnostics.metrics.summary_space_comparison(
+ observed_data=observed_data,
+ reference_data=reference_data,
+ approximator=mock_approximator,
+ num_null_samples=num_null_samples,
+ comparison_fn=bf.metrics.functional.maximum_mean_discrepancy,
+ )
+
+
+def test_mmd_comparison_approximator_incorrect_instance():
+ """Test mmd_comparison raises ValueError for incorrect approximator instance."""
+ observed_data = dict(observables=np.random.rand(10, 5))
+ reference_data = dict(observables=np.random.rand(100, 5))
+ num_null_samples = 50
+
+ class IncorrectApproximator:
+ pass
+
+ mock_approximator = IncorrectApproximator()
+
+ with pytest.raises(ValueError):
+ bf.diagnostics.metrics.summary_space_comparison(
+ observed_data=observed_data,
+ reference_data=reference_data,
+ approximator=mock_approximator,
+ num_null_samples=num_null_samples,
+ comparison_fn=bf.metrics.functional.maximum_mean_discrepancy,
+ )
diff --git a/tests/test_utils/test_serialize_deserialize.py b/tests/test_utils/test_serialize_deserialize.py
index 6c3bc3983..a9888ecc5 100644
--- a/tests/test_utils/test_serialize_deserialize.py
+++ b/tests/test_utils/test_serialize_deserialize.py
@@ -3,7 +3,7 @@
from bayesflow.utils.serialization import deserialize, serializable, serialize
-@serializable
+@serializable("test", disable_module_check=True)
class Foo:
@classmethod
def from_config(cls, config, custom_objects=None):
@@ -13,7 +13,7 @@ def get_config(self):
return {}
-@serializable
+@serializable("test", disable_module_check=True)
class Bar:
@classmethod
def from_config(cls, config, custom_objects=None):
diff --git a/tests/test_workflows/conftest.py b/tests/test_workflows/conftest.py
index c98e543e9..e66b6efe4 100644
--- a/tests/test_workflows/conftest.py
+++ b/tests/test_workflows/conftest.py
@@ -40,7 +40,7 @@ def summary_network(request):
elif request.param == "custom":
from bayesflow.networks import SummaryNetwork
- @serializable
+ @serializable("test", disable_module_check=True)
class Custom(SummaryNetwork):
def __init__(self, **kwargs):
super().__init__(**kwargs)
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
index f36b02bbd..9c2affc22 100644
--- a/tests/utils/__init__.py
+++ b/tests/utils/__init__.py
@@ -2,4 +2,5 @@
from .callbacks import *
from .check_combinations import *
from .jupyter import *
+from .networks import *
from .ops import *
diff --git a/tests/utils/networks.py b/tests/utils/networks.py
new file mode 100644
index 000000000..cf35e1463
--- /dev/null
+++ b/tests/utils/networks.py
@@ -0,0 +1,8 @@
+from bayesflow.networks import SummaryNetwork
+import keras
+
+
+class MeanStdSummaryNetwork(SummaryNetwork):
+ def call(self, x):
+ summary_outputs = keras.ops.stack([keras.ops.mean(x, axis=-1), keras.ops.std(x, axis=-1)], axis=-1)
+ return summary_outputs