Commit 586d69f

Merge branch 'fix-docs-link' into updated-v1-to-v2-notebook

2 parents: 68fff99 + b66b18e

94 files changed: +2833 additions, -786 deletions

.github/workflows/tests.yaml

Lines changed: 7 additions & 3 deletions
@@ -7,12 +7,10 @@ on:
     branches:
       - main
       - dev
-      - update-workflows
   push:
     branches:
       - main
       - dev
-      - update-workflows
 
 defaults:
   run:
@@ -72,7 +70,13 @@ jobs:
 
       - name: Run Tests
         run: |
-          pytest -x
+          pytest -x -m "not slow"
+
+      - name: Run Slow Tests
+        # run all slow tests only on manual trigger
+        if: github.event_name == 'workflow_dispatch'
+        run: |
+          pytest -m "slow"
 
       - name: Upload test results to Codecov
         if: ${{ !cancelled() }}
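
This change splits the suite into a fast default run (`pytest -x -m "not slow"`) and a manually triggered slow run (`pytest -m "slow"`). For orientation, a minimal sketch of how a test would opt into that marker is shown below; the example test and the marker registration in `pyproject.toml`/`pytest.ini` are assumptions, not part of this commit.

```python
import time

import pytest


# Hypothetical example: tests tagged "slow" are deselected by the default
# `pytest -x -m "not slow"` run and only collected by the manual
# `pytest -m "slow"` run. Registering the marker (e.g. under
# [tool.pytest.ini_options] markers) is assumed, not shown in this commit.
@pytest.mark.slow
def test_end_to_end_training_runs():
    time.sleep(2)  # stand-in for an expensive end-to-end test
    assert True
```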

README.md

Lines changed: 4 additions & 4 deletions
@@ -37,14 +37,14 @@ Using the high-level interface is easy, as demonstrated by the minimal working e
 import bayesflow as bf
 
 workflow = bf.BasicWorkflow(
-    inference_network=bf.networks.FlowMatching(),
-    summary_network=bf.networks.TimeSeriesTransformer(),
+    inference_network=bf.networks.CouplingFlow(),
+    summary_network=bf.networks.TimeSeriesNetwork(),
     inference_variables=["parameters"],
     summary_variables=["observables"],
     simulator=bf.simulators.SIR()
 )
 
-history = workflow.fit_online(epochs=50, batch_size=32, num_batches_per_epoch=500)
+history = workflow.fit_online(epochs=15, batch_size=32, num_batches_per_epoch=200)
 
 diagnostics = workflow.plot_default_diagnostics(test_data=300)
 ```
@@ -54,7 +54,7 @@ For an in-depth exposition, check out our walkthrough notebooks below.
 1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb)
 2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb)
 3. [Two moons starter example](examples/Two_Moons_Starter.ipynb)
-4. [Rapid iteration with point estimators](examples/Lotka_Volterra_point_estimation_and_expert_stats.ipynb)
+4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation_and_Expert_Stats.ipynb)
 5. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
 6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
 7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)

bayesflow/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -7,8 +7,9 @@
     experimental,
     networks,
     simulators,
-    workflows,
     utils,
+    workflows,
+    wrappers,
 )
 
 from .adapters import Adapter
@@ -40,8 +41,12 @@ def setup():
         torch.autograd.set_grad_enabled(False)
 
         logging.warning(
+            "\n"
             "When using torch backend, we need to disable autograd by default to avoid excessive memory usage. Use\n"
+            "\n"
             "with torch.enable_grad():\n"
+            "    ...\n"
+            "\n"
             "in contexts where you need gradients (e.g. custom training loops)."
         )
 
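
The expanded warning spells out the recommended pattern for re-enabling gradients under the torch backend. A minimal sketch of that pattern (the tensor and loss below are placeholders, not BayesFlow code):

```python
import torch

# With the torch backend, BayesFlow's setup() disables autograd globally.
# Re-enable it locally wherever gradients are actually needed, e.g. inside
# a custom training loop. The tensor and loss here are placeholders.
with torch.enable_grad():
    x = torch.randn(8, 3, requires_grad=True)
    loss = (x ** 2).sum()
    loss.backward()

print(x.grad.shape)  # gradients were tracked inside the enable_grad block
```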

bayesflow/adapters/adapter.py

Lines changed: 52 additions & 23 deletions
@@ -1,6 +1,7 @@
-from collections.abc import MutableSequence, Sequence
+from collections.abc import MutableSequence, Sequence, Mapping
 
 import numpy as np
+
 from keras.saving import (
     deserialize_keras_object as deserialize,
     register_keras_serializable as serializable,
@@ -121,16 +122,16 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]:
 
         return data
 
-    def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) -> dict[str, np.ndarray]:
+    def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs) -> dict[str, np.ndarray]:
         """Apply the transforms in the given direction.
 
         Parameters
         ----------
-        data : dict
+        data : Mapping[str, any]
            The data to be transformed.
         inverse : bool, optional
            If False, apply the forward transform, else apply the inverse transform (default False).
-        **kwargs : dict
+        **kwargs
            Additional keyword arguments passed to each transform.
 
         Returns
@@ -233,28 +234,25 @@ def __len__(self):
 
     def apply(
         self,
+        include: str | Sequence[str] = None,
         *,
         forward: np.ufunc | str,
         inverse: np.ufunc | str = None,
         predicate: Predicate = None,
-        include: str | Sequence[str] = None,
         exclude: str | Sequence[str] = None,
         **kwargs,
     ):
         """Append a :py:class:`~transforms.NumpyTransform` to the adapter.
 
         Parameters
         ----------
-        forward: callable, no lambda
-            Function to transform the data in the forward pass.
-            For the adapter to be serializable, this function has to be serializable
-            as well (see Notes). Therefore, only proper functions and no lambda
-            functions should be used here.
-        inverse: callable, no lambda
-            Function to transform the data in the inverse pass.
-            For the adapter to be serializable, this function has to be serializable
-            as well (see Notes). Therefore, only proper functions and no lambda
-            functions should be used here.
+        forward : str or np.ufunc
+            The name of the NumPy function to use for the forward transformation.
+        inverse : str or np.ufunc, optional
+            The name of the NumPy function to use for the inverse transformation.
+            By default, the inverse is inferred from the forward argument for supported methods.
+            You can find the supported methods in
+            :py:const:`~bayesflow.adapters.transforms.NumpyTransform.INVERSE_METHODS`.
         predicate : Predicate, optional
            Function that indicates which variables should be transformed.
         include : str or Sequence of str, optional
@@ -263,12 +261,6 @@ def apply(
            Names of variables to exclude from the transform.
         **kwargs : dict
            Additional keyword arguments passed to the transform.
-
-        Notes
-        -----
-        Important: This is only serializable if the forward and inverse functions are serializable.
-        This most likely means you will have to pass the scope that the forward and inverse functions are contained in
-        to the `custom_objects` argument of the `deserialize` function when deserializing this class.
         """
         transform = FilterTransform(
             transform_constructor=NumpyTransform,
@@ -388,6 +380,7 @@ def convert_dtype(
         exclude: str | Sequence[str] = None,
     ):
         """Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
+        See also :py:meth:`~bayesflow.adapters.Adapter.map_dtype`.
 
         Parameters
         ----------
@@ -525,6 +518,24 @@ def log(self, keys: str | Sequence[str], *, p1: bool = False):
         self.transforms.append(transform)
         return self
 
+    def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
+        """Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
+        See also :py:meth:`~bayesflow.adapters.Adapter.convert_dtype`.
+
+        Parameters
+        ----------
+        keys : str or Sequence of str
+            The names of the variables to transform.
+        to_dtype : str
+            Target dtype
+        """
+        if isinstance(keys, str):
+            keys = [keys]
+
+        transform = MapTransform({key: ConvertDType(to_dtype) for key in keys})
+        self.transforms.append(transform)
+        return self
+
     def one_hot(self, keys: str | Sequence[str], num_classes: int):
         """Append a :py:class:`~transforms.OneHot` transform to the adapter.
 
@@ -555,6 +566,24 @@ def rename(self, from_key: str, to_key: str):
         self.transforms.append(Rename(from_key, to_key))
         return self
 
+    def scale(self, keys: str | Sequence[str], by: float | np.ndarray):
+        from .transforms import Scale
+
+        if isinstance(keys, str):
+            keys = [keys]
+
+        self.transforms.append(MapTransform({key: Scale(scale=by) for key in keys}))
+        return self
+
+    def shift(self, keys: str | Sequence[str], by: float | np.ndarray):
+        from .transforms import Shift
+
+        if isinstance(keys, str):
+            keys = [keys]
+
+        self.transforms.append(MapTransform({key: Shift(shift=by) for key in keys}))
+        return self
+
     def sqrt(self, keys: str | Sequence[str]):
         """Append an :py:class:`~transforms.Sqrt` transform to the adapter.
 
@@ -572,9 +601,9 @@ def sqrt(self, keys: str | Sequence[str]):
 
     def standardize(
         self,
+        include: str | Sequence[str] = None,
         *,
         predicate: Predicate = None,
-        include: str | Sequence[str] = None,
         exclude: str | Sequence[str] = None,
         **kwargs,
     ):
@@ -603,9 +632,9 @@
 
     def to_array(
         self,
+        include: str | Sequence[str] = None,
         *,
         predicate: Predicate = None,
-        include: str | Sequence[str] = None,
         exclude: str | Sequence[str] = None,
         **kwargs,
    ):
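
Taken together, these changes add chainable `map_dtype`, `scale`, and `shift` methods and move `include` to the first positional parameter of `apply`, `standardize`, and `to_array`. A hedged usage sketch of the reworked API follows; the variable names and sample shapes are invented for illustration, and the empty `bf.Adapter()` constructor is an assumption not shown in this diff.

```python
import numpy as np
import bayesflow as bf

# Hypothetical keys and shapes, chosen only to illustrate the chainable API.
adapter = (
    bf.Adapter()
    .to_array(include=["parameters", "observables"])
    .map_dtype("parameters", to_dtype="float32")  # same effect as convert_dtype
    .shift("parameters", by=1.0)
    .scale("parameters", by=0.5)
    .apply("observables", forward="exp")          # inverse (np.log) inferred via INVERSE_METHODS
    .standardize("parameters")                    # include is now the first positional argument
)

data = {"parameters": np.random.rand(4, 2), "observables": np.random.rand(4, 10)}
forwarded = adapter(data)                      # forward direction
recovered = adapter(forwarded, inverse=True)   # inverse direction
```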

bayesflow/adapters/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@
 from .numpy_transform import NumpyTransform
 from .one_hot import OneHot
 from .rename import Rename
+from .scale import Scale
+from .shift import Shift
 from .sqrt import Sqrt
 from .standardize import Standardize
 from .to_array import ToArray

bayesflow/adapters/transforms/filter_transform.py

Lines changed: 1 addition & 1 deletion
@@ -29,10 +29,10 @@ class FilterTransform(Transform):
 
     def __init__(
         self,
+        include: str | Sequence[str] = None,
         *,
         transform_constructor: Callable[..., ElementwiseTransform],
         predicate: Predicate = None,
-        include: str | Sequence[str] = None,
         exclude: str | Sequence[str] = None,
         **kwargs,
    ):

bayesflow/adapters/transforms/numpy_transform.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ class NumpyTransform(ElementwiseTransform):
         The name of the NumPy function to apply in the inverse transformation.
     """
 
+    #: Dict of `np.ufunc` that support automatic selection of their inverse.
     INVERSE_METHODS = {
         np.arctan: np.tan,
         np.exp: np.log,
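
The newly documented `INVERSE_METHODS` mapping is what lets `NumpyTransform` (and `Adapter.apply`) fill in the inverse when only a forward function is given. A self-contained sketch of that lookup, restricted to the two pairs visible in this diff:

```python
import numpy as np

# Only the two pairs visible in the diff are reproduced here; the full
# NumpyTransform.INVERSE_METHODS mapping may contain further entries.
INVERSE_METHODS = {
    np.arctan: np.tan,
    np.exp: np.log,
}

forward = np.exp
inverse = INVERSE_METHODS[forward]  # np.log, selected automatically

x = np.array([0.5, 1.0, 2.0])
np.testing.assert_allclose(inverse(forward(x)), x)
```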
bayesflow/adapters/transforms/scale.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+from keras.saving import (
+    deserialize_keras_object as deserialize,
+    register_keras_serializable as serializable,
+    serialize_keras_object as serialize,
+)
+import numpy as np
+
+from .elementwise_transform import ElementwiseTransform
+
+
+@serializable(package="bayesflow.adapters")
+class Scale(ElementwiseTransform):
+    def __init__(self, scale: np.typing.ArrayLike):
+        self.scale = np.array(scale)
+
+    @classmethod
+    def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
+        return cls(scale=deserialize(config["scale"]))
+
+    def get_config(self) -> dict:
+        return {"scale": serialize(self.scale)}
+
+    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
+        return data * self.scale
+
+    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
+        return data / self.scale
bayesflow/adapters/transforms/shift.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+from keras.saving import (
+    deserialize_keras_object as deserialize,
+    register_keras_serializable as serializable,
+    serialize_keras_object as serialize,
+)
+import numpy as np
+
+from .elementwise_transform import ElementwiseTransform
+
+
+@serializable(package="bayesflow.adapters")
+class Shift(ElementwiseTransform):
+    def __init__(self, shift: np.typing.ArrayLike):
+        self.shift = np.array(shift)
+
+    @classmethod
+    def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
+        return cls(shift=deserialize(config["shift"]))
+
+    def get_config(self) -> dict:
+        return {"shift": serialize(self.shift)}
+
+    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
+        return data + self.shift
+
+    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
+        return data - self.shift
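
Both new transforms are simple elementwise affine operations with exact inverses, which is what `Adapter.scale` and `Adapter.shift` wrap per key. A quick round-trip sketch (the sample data is arbitrary):

```python
import numpy as np

from bayesflow.adapters.transforms import Scale, Shift

# Arbitrary sample data; per-feature parameter arrays also work because the
# scale/shift values broadcast like ordinary NumPy arrays.
x = np.array([[1.0, 2.0], [3.0, 4.0]])

scale = Scale(scale=np.array([2.0, 10.0]))
shift = Shift(shift=0.5)

y = shift.forward(scale.forward(x))
x_back = scale.inverse(shift.inverse(y))
np.testing.assert_allclose(x_back, x)
```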

bayesflow/approximators/approximator.py

Lines changed: 6 additions & 3 deletions
@@ -1,6 +1,9 @@
-import keras
+from collections.abc import Mapping
+
 import multiprocessing as mp
 
+import keras
+
 from bayesflow.adapters import Adapter
 from bayesflow.datasets import OnlineDataset
 from bayesflow.simulators import Simulator
@@ -19,7 +22,7 @@ def build_adapter(cls, **kwargs) -> Adapter:
         # implemented by each respective architecture
         raise NotImplementedError
 
-    def build_from_data(self, data: dict[str, any]) -> None:
+    def build_from_data(self, data: Mapping[str, any]) -> None:
         self.compute_metrics(**data, stage="training")
         self.built = True
 
@@ -72,7 +75,7 @@ def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = N
            A dataset containing simulations for training. If provided, `simulator` must be None.
         simulator : Simulator, optional
            A simulator used to generate a dataset. If provided, `dataset` must be None.
-        **kwargs : dict
+        **kwargs
            Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):
 
            batch_size : int or None, default='auto'
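
`Approximator.fit` accepts either a prebuilt `keras.utils.PyDataset` or a `Simulator`, but not both, and forwards the remaining keyword arguments to `keras.Model.fit()`. A hedged sketch of that contract; the wrapper function and argument values are illustrative, and no concrete approximator is constructed here.

```python
import bayesflow as bf

# Illustrative wrapper mirroring the documented contract of Approximator.fit:
# exactly one of `dataset` or `simulator` must be provided; everything else is
# forwarded to keras.Model.fit(). The approximator itself is assumed to be an
# instance of some concrete subclass and is not built in this sketch.
def train(approximator, *, dataset=None, simulator=None, **fit_kwargs):
    if (dataset is None) == (simulator is None):
        raise ValueError("Provide exactly one of `dataset` or `simulator`.")
    return approximator.fit(dataset=dataset, simulator=simulator, **fit_kwargs)

# e.g. train(approximator, simulator=bf.simulators.SIR(), epochs=10, batch_size=32)
```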
