Skip to content

Commit c58bb50

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into fix-additional-metrics
2 parents f536064 + 8afff13 commit c58bb50

File tree

63 files changed

+1823
-1188
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+1823
-1188
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ Many examples from [Bayesian Cognitive Modeling: A Practical Course](https://bay
135135
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
136136
7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)
137137
8. [Likelihood estimation](examples/Likelihood_Estimation.ipynb)
138-
9. [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb)
138+
9. [Multimodal data](examples/Multimodal_Data.ipynb)
139+
10. [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb)
139140

140141
More tutorials are always welcome! Please consider making a pull request if you have a cool application that you want to contribute.
141142

bayesflow/__init__.py

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,64 @@
1-
from . import (
2-
approximators,
3-
adapters,
4-
datasets,
5-
diagnostics,
6-
distributions,
7-
experimental,
8-
networks,
9-
simulators,
10-
utils,
11-
workflows,
12-
wrappers,
13-
)
14-
15-
from .adapters import Adapter
16-
from .approximators import ContinuousApproximator, PointApproximator
17-
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
18-
from .simulators import make_simulator
19-
from .workflows import BasicWorkflow
1+
# ruff: noqa: E402
2+
# disable E402 to allow for setup code before importing any internals (which could import keras)
203

214

225
def setup():
236
# perform any necessary setup without polluting the namespace
7+
import os
8+
from importlib.util import find_spec
9+
10+
issue_url = "https://github.com/bayesflow-org/bayesflow/issues/new?template=bug_report.md"
11+
12+
if "KERAS_BACKEND" not in os.environ:
13+
# check for available backends and automatically set the KERAS_BACKEND env variable or raise an error
14+
class Backend:
15+
def __init__(self, display_name, package_name, env_name, install_url, priority):
16+
self.display_name = display_name
17+
self.package_name = package_name
18+
self.env_name = env_name
19+
self.install_url = install_url
20+
self.priority = priority
21+
22+
backends = [
23+
Backend("JAX", "jax", "jax", "https://docs.jax.dev/en/latest/quickstart.html#installation", 0),
24+
Backend("PyTorch", "torch", "torch", "https://pytorch.org/get-started/locally/", 1),
25+
Backend("TensorFlow", "tensorflow", "tensorflow", "https://www.tensorflow.org/install", 2),
26+
]
27+
28+
found_backends = []
29+
for backend in backends:
30+
if find_spec(backend.package_name) is not None:
31+
found_backends.append(backend)
32+
33+
if not found_backends:
34+
message = "No suitable backend found. Please install one of the following:\n"
35+
for backend in backends:
36+
message += f"{backend.display_name}\n"
37+
message += "\n"
38+
39+
message += f"If you continue to see this error, please file a bug report at {issue_url}.\n"
40+
message += (
41+
"You can manually select a backend by setting the KERAS_BACKEND environment variable as shown below:\n"
42+
)
43+
message += "https://keras.io/getting_started/#configuring-your-backend"
44+
45+
raise ImportError(message)
46+
47+
if len(found_backends) > 1:
48+
import warnings
49+
50+
found_backends.sort(key=lambda b: b.priority)
51+
chosen_backend = found_backends[0]
52+
53+
warnings.warn(
54+
f"Multiple Keras-compatible backends detected ({', '.join(b.display_name for b in found_backends)}).\n"
55+
f"Defaulting to {chosen_backend.display_name}.\n"
56+
"To override, set the KERAS_BACKEND environment variable before importing bayesflow.\n"
57+
"See: https://keras.io/getting_started/#configuring-your-backend"
58+
)
59+
else:
60+
os.environ["KERAS_BACKEND"] = found_backends[0].env_name
61+
2462
import keras
2563
import logging
2664

@@ -59,3 +97,24 @@ def setup():
5997
# call and clean up namespace
6098
setup()
6199
del setup
100+
101+
from . import (
102+
approximators,
103+
adapters,
104+
augmentations,
105+
datasets,
106+
diagnostics,
107+
distributions,
108+
experimental,
109+
networks,
110+
simulators,
111+
utils,
112+
workflows,
113+
wrappers,
114+
)
115+
116+
from .adapters import Adapter
117+
from .approximators import ContinuousApproximator, PointApproximator
118+
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
119+
from .simulators import make_simulator
120+
from .workflows import BasicWorkflow

bayesflow/adapters/adapter.py

Lines changed: 22 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Callable, MutableSequence, Sequence, Mapping
1+
from collections.abc import Callable, MutableSequence, Sequence
22

33
import numpy as np
44

@@ -18,7 +18,6 @@
1818
Keep,
1919
Log,
2020
MapTransform,
21-
NNPE,
2221
NumpyTransform,
2322
OneHot,
2423
Rename,
@@ -87,16 +86,14 @@ def get_config(self) -> dict:
8786
return serialize(config)
8887

8988
def forward(
90-
self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
89+
self, data: dict[str, any], *, log_det_jac: bool = False, **kwargs
9190
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
9291
"""Apply the transforms in the forward direction.
9392
9493
Parameters
9594
----------
96-
data : dict
95+
data : dict[str, any]
9796
The data to be transformed.
98-
stage : str, one of ["training", "validation", "inference"]
99-
The stage the function is called in.
10097
log_det_jac: bool, optional
10198
Whether to return the log determinant of the Jacobian of the transforms.
10299
**kwargs : dict
@@ -110,28 +107,26 @@ def forward(
110107
data = data.copy()
111108
if not log_det_jac:
112109
for transform in self.transforms:
113-
data = transform(data, stage=stage, **kwargs)
110+
data = transform(data, **kwargs)
114111
return data
115112

116113
log_det_jac = {}
117114
for transform in self.transforms:
118-
transformed_data = transform(data, stage=stage, **kwargs)
115+
transformed_data = transform(data, **kwargs)
119116
log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
120117
data = transformed_data
121118

122119
return data, log_det_jac
123120

124121
def inverse(
125-
self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
122+
self, data: dict[str, any], *, log_det_jac: bool = False, **kwargs
126123
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
127124
"""Apply the transforms in the inverse direction.
128125
129126
Parameters
130127
----------
131-
data : dict
128+
data : dict[str, any]
132129
The data to be transformed.
133-
stage : str, one of ["training", "validation", "inference"]
134-
The stage the function is called in.
135130
log_det_jac: bool, optional
136131
Whether to return the log determinant of the Jacobian of the transforms.
137132
**kwargs : dict
@@ -145,18 +140,18 @@ def inverse(
145140
data = data.copy()
146141
if not log_det_jac:
147142
for transform in reversed(self.transforms):
148-
data = transform(data, stage=stage, inverse=True, **kwargs)
143+
data = transform(data, inverse=True, **kwargs)
149144
return data
150145

151146
log_det_jac = {}
152147
for transform in reversed(self.transforms):
153-
data = transform(data, stage=stage, inverse=True, **kwargs)
148+
data = transform(data, inverse=True, **kwargs)
154149
log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs)
155150

156151
return data, log_det_jac
157152

158153
def __call__(
159-
self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
154+
self, data: dict[str, any], *, inverse: bool = False, **kwargs
160155
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
161156
"""Apply the transforms in the given direction.
162157
@@ -166,8 +161,6 @@ def __call__(
166161
The data to be transformed.
167162
inverse : bool, optional
168163
If False, apply the forward transform, else apply the inverse transform (default False).
169-
stage : str, one of ["training", "validation", "inference"]
170-
The stage the function is called in.
171164
**kwargs
172165
Additional keyword arguments passed to each transform.
173166
@@ -177,9 +170,9 @@ def __call__(
177170
The transformed data or tuple of transformed data and log determinant of the Jacobian.
178171
"""
179172
if inverse:
180-
return self.inverse(data, stage=stage, **kwargs)
173+
return self.inverse(data, **kwargs)
181174

182-
return self.forward(data, stage=stage, **kwargs)
175+
return self.forward(data, **kwargs)
183176

184177
def __repr__(self):
185178
result = ""
@@ -701,43 +694,6 @@ def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
701694
self.transforms.append(transform)
702695
return self
703696

704-
def nnpe(
705-
self,
706-
keys: str | Sequence[str],
707-
*,
708-
spike_scale: float | None = None,
709-
slab_scale: float | None = None,
710-
per_dimension: bool = True,
711-
seed: int | None = None,
712-
):
713-
"""Append an :py:class:`~transforms.NNPE` transform to the adapter.
714-
715-
Parameters
716-
----------
717-
keys : str or Sequence of str
718-
The names of the variables to transform.
719-
spike_scale : float or np.ndarray or None, default=None
720-
The scale of the spike (Normal) distribution. Automatically determined if None.
721-
slab_scale : float or np.ndarray or None, default=None
722-
The scale of the slab (Cauchy) distribution. Automatically determined if None.
723-
per_dimension : bool, default=True
724-
If true, noise is applied per dimension of the last axis of the input data.
725-
If false, noise is applied globally.
726-
seed : int or None
727-
The seed for the random number generator. If None, a random seed is used.
728-
"""
729-
if isinstance(keys, str):
730-
keys = [keys]
731-
732-
transform = MapTransform(
733-
{
734-
key: NNPE(spike_scale=spike_scale, slab_scale=slab_scale, per_dimension=per_dimension, seed=seed)
735-
for key in keys
736-
}
737-
)
738-
self.transforms.append(transform)
739-
return self
740-
741697
def one_hot(self, keys: str | Sequence[str], num_classes: int):
742698
"""Append a :py:class:`~transforms.OneHot` transform to the adapter.
743699
@@ -857,6 +813,8 @@ def standardize(
857813
self,
858814
include: str | Sequence[str] = None,
859815
*,
816+
mean: int | float | np.ndarray,
817+
std: int | float | np.ndarray,
860818
predicate: Predicate = None,
861819
exclude: str | Sequence[str] = None,
862820
**kwargs,
@@ -865,10 +823,14 @@ def standardize(
865823
866824
Parameters
867825
----------
868-
predicate : Predicate, optional
869-
Function that indicates which variables should be transformed.
870826
include : str or Sequence of str, optional
871827
Names of variables to include in the transform.
828+
mean : int or float
829+
Specifies the mean (location) of the transform.
830+
std : int or float
831+
Specifies the standard deviation (scale) of the transform.
832+
predicate : Predicate, optional
833+
Function that indicates which variables should be transformed.
872834
exclude : str or Sequence of str, optional
873835
Names of variables to exclude from the transform.
874836
**kwargs :
@@ -879,6 +841,8 @@ def standardize(
879841
predicate=predicate,
880842
include=include,
881843
exclude=exclude,
844+
mean=mean,
845+
std=std,
882846
**kwargs,
883847
)
884848
self.transforms.append(transform)

bayesflow/adapters/transforms/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from .keep import Keep
1313
from .log import Log
1414
from .map_transform import MapTransform
15-
from .nnpe import NNPE
1615
from .numpy_transform import NumpyTransform
1716
from .one_hot import OneHot
1817
from .rename import Rename

bayesflow/adapters/transforms/broadcast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def forward(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray
117117
data[k] = np.expand_dims(data[k], axis=tuple(np.arange(0, len_diff)))
118118
elif self.expand == "right":
119119
data[k] = np.expand_dims(data[k], axis=tuple(-np.arange(1, len_diff + 1)))
120-
elif isinstance(self.expand, tuple):
120+
elif isinstance(self.expand, Sequence):
121121
if len(self.expand) is not len_diff:
122122
raise ValueError("Length of `expand` must match the length difference of the involed arrays.")
123123
data[k] = np.expand_dims(data[k], axis=self.expand)

0 commit comments

Comments
 (0)