Skip to content

Commit 6572f06

Browse files
committed
Merge branch 'dev' into feat-diffusion-model
2 parents 78814ac + b4d0a72 commit 6572f06

File tree

197 files changed

+1654
-590
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

197 files changed

+1654
-590
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,6 @@ docs/
3939

4040
# MacOS
4141
.DS_Store
42+
43+
# Rproj
44+
.Rproj.user

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ It provides users and researchers with:
1515
BayesFlow (version 2+) is designed to be a flexible and efficient tool that enables rapid statistical inference
1616
fueled by continuous progress in generative AI and Bayesian inference.
1717

18+
> [!IMPORTANT]
19+
> As the 2.0 version introduced many new features, we still have to make breaking changes from time to time.
20+
> This especially concerns **saving and loading** of models. We aim to stabilize this from the 2.1 release onwards.
21+
> Until then, consider pinning your BayesFlow 2.0 installation to an exact version, or re-training after an update
22+
> for less costly models.
23+
1824
## Important Note for Existing Users
1925

2026
You are currently looking at BayesFlow 2.0+, which is a complete rewrite of the library.
@@ -245,8 +251,9 @@ Depending on your needs, you might not want to upgrade yet if one of the followi
245251
with the new version. Loading models from version 1.x in version 2.0+ is not supported.
246252
- You require a feature that was not ported to BayesFlow 2.0+ yet. To our knowledge,
247253
this applies to:
248-
* Two-level/Hierarchical models: `TwoLevelGenerativeModel`, `TwoLevelPrior`.
249-
* Sensitivity analysis: functionality from the `bayesflow.sensitivity` module.
254+
* Two-level/Hierarchical models (planned for version 2.1): `TwoLevelGenerativeModel`, `TwoLevelPrior`.
255+
* Sensitivity analysis (partially discontinued): functionality from the `bayesflow.sensitivity` module. This is still
256+
possible, but we do no longer offer a special module for it. We plan to add a tutorial on this, see [#455](https://github.com/bayesflow-org/bayesflow/issues/455).
250257
* MCMC (discontinued): The `bayesflow.mcmc` module. We are considering other options
251258
to enable the use of BayesFlow in an MCMC setting.
252259
* Networks: `EvidentialNetwork`.

bayesflow/__init__.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,11 @@ def setup():
5050
"in contexts where you need gradients (e.g. custom training loops)."
5151
)
5252

53+
# dynamically add __version__ attribute
54+
from importlib.metadata import version
5355

54-
# dynamically add version dunder variable
55-
try:
56-
from importlib.metadata import version, PackageNotFoundError
56+
globals()["__version__"] = version("bayesflow")
5757

58-
__version__ = version(__name__)
59-
except PackageNotFoundError:
60-
__version__ = "2.0.0"
61-
finally:
62-
del version
63-
del PackageNotFoundError
6458

6559
# call and clean up namespace
6660
setup()

bayesflow/adapters/adapter.py

Lines changed: 95 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,13 @@
2525
Standardize,
2626
ToArray,
2727
Transform,
28+
RandomSubsample,
29+
Take,
2830
)
2931
from .transforms.filter_transform import Predicate
3032

3133

32-
@serializable
34+
@serializable("bayesflow.adapters")
3335
class Adapter(MutableSequence[Transform]):
3436
"""
3537
Defines an adapter to apply various transforms to data.
@@ -79,7 +81,9 @@ def get_config(self) -> dict:
7981

8082
return serialize(config)
8183

82-
def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]:
84+
def forward(
85+
self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
86+
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
8387
"""Apply the transforms in the forward direction.
8488
8589
Parameters
@@ -88,22 +92,33 @@ def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -
8892
The data to be transformed.
8993
stage : str, one of ["training", "validation", "inference"]
9094
The stage the function is called in.
95+
log_det_jac: bool, optional
96+
Whether to return the log determinant of the Jacobian of the transforms.
9197
**kwargs : dict
9298
Additional keyword arguments passed to each transform.
9399
94100
Returns
95101
-------
96-
dict
97-
The transformed data.
102+
dict | tuple[dict, dict]
103+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
98104
"""
99105
data = data.copy()
106+
if not log_det_jac:
107+
for transform in self.transforms:
108+
data = transform(data, stage=stage, **kwargs)
109+
return data
100110

111+
log_det_jac = {}
101112
for transform in self.transforms:
102-
data = transform(data, stage=stage, **kwargs)
113+
transformed_data = transform(data, stage=stage, **kwargs)
114+
log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
115+
data = transformed_data
103116

104-
return data
117+
return data, log_det_jac
105118

106-
def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]:
119+
def inverse(
120+
self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
121+
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
107122
"""Apply the transforms in the inverse direction.
108123
109124
Parameters
@@ -112,24 +127,32 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kw
112127
The data to be transformed.
113128
stage : str, one of ["training", "validation", "inference"]
114129
The stage the function is called in.
130+
log_det_jac: bool, optional
131+
Whether to return the log determinant of the Jacobian of the transforms.
115132
**kwargs : dict
116133
Additional keyword arguments passed to each transform.
117134
118135
Returns
119136
-------
120-
dict
121-
The transformed data.
137+
dict | tuple[dict, dict]
138+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
122139
"""
123140
data = data.copy()
141+
if not log_det_jac:
142+
for transform in reversed(self.transforms):
143+
data = transform(data, stage=stage, inverse=True, **kwargs)
144+
return data
124145

146+
log_det_jac = {}
125147
for transform in reversed(self.transforms):
126148
data = transform(data, stage=stage, inverse=True, **kwargs)
149+
log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs)
127150

128-
return data
151+
return data, log_det_jac
129152

130153
def __call__(
131154
self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
132-
) -> dict[str, np.ndarray]:
155+
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
133156
"""Apply the transforms in the given direction.
134157
135158
Parameters
@@ -145,8 +168,8 @@ def __call__(
145168
146169
Returns
147170
-------
148-
dict
149-
The transformed data.
171+
dict | tuple[dict, dict]
172+
The transformed data or tuple of transformed data and log determinant of the Jacobian.
150173
"""
151174
if inverse:
152175
return self.inverse(data, stage=stage, **kwargs)
@@ -644,6 +667,28 @@ def one_hot(self, keys: str | Sequence[str], num_classes: int):
644667
self.transforms.append(transform)
645668
return self
646669

670+
def random_subsample(self, key: str, *, sample_size: int | float, axis: int = -1):
671+
"""
672+
Append a :py:class:`~transforms.RandomSubsample` transform to the adapter.
673+
674+
Parameters
675+
----------
676+
key : str or Sequence of str
677+
The name of the variable to subsample.
678+
sample_size : int or float
679+
The number of samples to draw, or a fraction between 0 and 1 of the total number of samples to draw.
680+
axis: int, optional
681+
Which axis to draw samples over. The last axis is used by default.
682+
"""
683+
684+
if not isinstance(key, str):
685+
raise TypeError("Can only subsample one batch entry at a time.")
686+
687+
transform = MapTransform({key: RandomSubsample(sample_size=sample_size, axis=axis)})
688+
689+
self.transforms.append(transform)
690+
return self
691+
647692
def rename(self, from_key: str, to_key: str):
648693
"""Append a :py:class:`~transforms.Rename` transform to the adapter.
649694
@@ -720,7 +765,7 @@ def standardize(
720765
Names of variables to include in the transform.
721766
exclude : str or Sequence of str, optional
722767
Names of variables to exclude from the transform.
723-
**kwargs : dict
768+
**kwargs :
724769
Additional keyword arguments passed to the transform.
725770
"""
726771
transform = FilterTransform(
@@ -733,6 +778,42 @@ def standardize(
733778
self.transforms.append(transform)
734779
return self
735780

781+
def take(
782+
self,
783+
include: str | Sequence[str] = None,
784+
*,
785+
indices: Sequence[int],
786+
axis: int = -1,
787+
predicate: Predicate = None,
788+
exclude: str | Sequence[str] = None,
789+
):
790+
"""
791+
Append a :py:class:`~transforms.Take` transform to the adapter.
792+
793+
Parameters
794+
----------
795+
include : str or Sequence of str, optional
796+
Names of variables to include in the transform.
797+
indices : Sequence of int
798+
Which indices to take from the data.
799+
axis : int, optional
800+
Which axis to take from. The last axis is used by default.
801+
predicate : Predicate, optional
802+
Function that indicates which variables should be transformed.
803+
exclude : str or Sequence of str, optional
804+
Names of variables to exclude from the transform.
805+
"""
806+
transform = FilterTransform(
807+
transform_constructor=Take,
808+
predicate=predicate,
809+
include=include,
810+
exclude=exclude,
811+
indices=indices,
812+
axis=axis,
813+
)
814+
self.transforms.append(transform)
815+
return self
816+
736817
def to_array(
737818
self,
738819
include: str | Sequence[str] = None,

bayesflow/adapters/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from .to_array import ToArray
2424
from .to_dict import ToDict
2525
from .transform import Transform
26+
from .random_subsample import RandomSubsample
27+
from .take import Take
2628

2729
from ...utils._docs import _add_imports_to_all
2830

bayesflow/adapters/transforms/as_set.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .elementwise_transform import ElementwiseTransform
66

77

8-
@serializable
8+
@serializable("bayesflow.adapters")
99
class AsSet(ElementwiseTransform):
1010
"""The `.as_set(["x", "y"])` transform indicates that both `x` and `y` are treated as sets.
1111

bayesflow/adapters/transforms/as_time_series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .elementwise_transform import ElementwiseTransform
66

77

8-
@serializable
8+
@serializable("bayesflow.adapters")
99
class AsTimeSeries(ElementwiseTransform):
1010
"""The `.as_time_series` transform can be used to indicate that variables shall be treated as time series.
1111

bayesflow/adapters/transforms/broadcast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .transform import Transform
77

88

9-
@serializable
9+
@serializable("bayesflow.adapters")
1010
class Broadcast(Transform):
1111
"""
1212
Broadcasts arrays or scalars to the shape of a given other array.

bayesflow/adapters/transforms/concatenate.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .transform import Transform
88

99

10-
@serializable
10+
@serializable("bayesflow.adapters")
1111
class Concatenate(Transform):
1212
"""Concatenate multiple arrays into a new key. Used to specify how data variables should be treated by the network.
1313
@@ -115,3 +115,37 @@ def extra_repr(self) -> str:
115115
result += f", axis={self.axis}"
116116

117117
return result
118+
119+
def log_det_jac(
120+
self,
121+
data: dict[str, np.ndarray],
122+
log_det_jac: dict[str, np.ndarray],
123+
*,
124+
strict: bool = False,
125+
inverse: bool = False,
126+
**kwargs,
127+
) -> dict[str, np.ndarray]:
128+
# copy to avoid side effects
129+
log_det_jac = log_det_jac.copy()
130+
131+
if inverse:
132+
if log_det_jac.get(self.into) is not None:
133+
raise ValueError(
134+
"Cannot obtain an inverse Jacobian of concatenation. "
135+
"Transform your variables before you concatenate."
136+
)
137+
138+
return log_det_jac
139+
140+
required_keys = set(self.keys)
141+
available_keys = set(log_det_jac.keys())
142+
common_keys = available_keys & required_keys
143+
144+
if len(common_keys) == 0:
145+
return log_det_jac
146+
147+
parts = [log_det_jac.pop(key) for key in common_keys]
148+
149+
log_det_jac[self.into] = sum(parts)
150+
151+
return log_det_jac

0 commit comments

Comments
 (0)