Skip to content

Commit 6e5e191

Browse files
committed
Merge branch 'dev' into allow-networks
# Conflicts: # bayesflow/experimental/free_form_flow/free_form_flow.py # bayesflow/networks/consistency_models/consistency_model.py # bayesflow/networks/coupling_flow/coupling_flow.py # bayesflow/networks/flow_matching/flow_matching.py # bayesflow/networks/time_series_network/skip_recurrent.py # bayesflow/networks/time_series_network/time_series_network.py
2 parents a3437fa + 2bf0b53 commit 6e5e191

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1964
-779
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ Using the high-level interface is easy, as demonstrated by the minimal working e
3737
import bayesflow as bf
3838

3939
workflow = bf.BasicWorkflow(
40-
inference_network=bf.networks.FlowMatching(),
41-
summary_network=bf.networks.TimeSeriesTransformer(),
40+
inference_network=bf.networks.CouplingFlow(),
41+
summary_network=bf.networks.TimeSeriesNetwork(),
4242
inference_variables=["parameters"],
4343
summary_variables=["observables"],
4444
simulator=bf.simulators.SIR()
4545
)
4646

47-
history = workflow.fit_online(epochs=50, batch_size=32, num_batches_per_epoch=500)
47+
history = workflow.fit_online(epochs=15, batch_size=32, num_batches_per_epoch=200)
4848

4949
diagnostics = workflow.plot_default_diagnostics(test_data=300)
5050
```
@@ -58,6 +58,7 @@ For an in-depth exposition, check out our walkthrough notebooks below.
5858
5. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
5959
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
6060
7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)
61+
8. [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb)
6162

6263
More tutorials are always welcome! Please consider making a pull request if you have a cool application that you want to contribute.
6364

bayesflow/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
experimental,
88
networks,
99
simulators,
10-
workflows,
1110
utils,
11+
workflows,
12+
wrappers,
1213
)
1314

1415
from .adapters import Adapter

bayesflow/adapters/adapter.py

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import MutableSequence, Sequence, Mapping
1+
from collections.abc import Callable, MutableSequence, Sequence, Mapping
22

33
import numpy as np
44

@@ -24,6 +24,7 @@
2424
NumpyTransform,
2525
OneHot,
2626
Rename,
27+
SerializableCustomTransform,
2728
Sqrt,
2829
Standardize,
2930
ToArray,
@@ -246,16 +247,58 @@ def apply(
246247
247248
Parameters
248249
----------
249-
forward: callable, no lambda
250-
Function to transform the data in the forward pass.
250+
forward : str or np.ufunc
251+
The name of the NumPy function to use for the forward transformation.
252+
inverse : str or np.ufunc, optional
253+
The name of the NumPy function to use for the inverse transformation.
254+
By default, the inverse is inferred from the forward argument for supported methods.
255+
You can find the supported methods in
256+
:py:const:`~bayesflow.adapters.transforms.NumpyTransform.INVERSE_METHODS`.
257+
predicate : Predicate, optional
258+
Function that indicates which variables should be transformed.
259+
include : str or Sequence of str, optional
260+
Names of variables to include in the transform.
261+
exclude : str or Sequence of str, optional
262+
Names of variables to exclude from the transform.
263+
**kwargs : dict
264+
Additional keyword arguments passed to the transform.
265+
"""
266+
transform = FilterTransform(
267+
transform_constructor=NumpyTransform,
268+
predicate=predicate,
269+
include=include,
270+
exclude=exclude,
271+
forward=forward,
272+
inverse=inverse,
273+
**kwargs,
274+
)
275+
self.transforms.append(transform)
276+
return self
277+
278+
def apply_serializable(
279+
self,
280+
include: str | Sequence[str] = None,
281+
*,
282+
forward: Callable[[np.ndarray, ...], np.ndarray],
283+
inverse: Callable[[np.ndarray, ...], np.ndarray],
284+
predicate: Predicate = None,
285+
exclude: str | Sequence[str] = None,
286+
**kwargs,
287+
):
288+
"""Append a :py:class:`~transforms.SerializableCustomTransform` to the adapter.
289+
290+
Parameters
291+
----------
292+
forward : function, no lambda
293+
Registered serializable function to transform the data in the forward pass.
251294
For the adapter to be serializable, this function has to be serializable
252295
as well (see Notes). Therefore, only proper functions and no lambda
253-
functions should be used here.
254-
inverse: callable, no lambda
255-
Function to transform the data in the inverse pass.
296+
functions can be used here.
297+
inverse : function, no lambda
298+
Registered serializable function to transform the data in the inverse pass.
256299
For the adapter to be serializable, this function has to be serializable
257300
as well (see Notes). Therefore, only proper functions and no lambda
258-
functions should be used here.
301+
functions can be used here.
259302
predicate : Predicate, optional
260303
Function that indicates which variables should be transformed.
261304
include : str or Sequence of str, optional
@@ -265,14 +308,45 @@ def apply(
265308
**kwargs : dict
266309
Additional keyword arguments passed to the transform.
267310
311+
Raises
312+
------
313+
ValueError
314+
When the provided functions are not registered serializable functions.
315+
268316
Notes
269317
-----
270-
Important: This is only serializable if the forward and inverse functions are serializable.
271-
This most likely means you will have to pass the scope that the forward and inverse functions are contained in
272-
to the `custom_objects` argument of the `deserialize` function when deserializing this class.
318+
Important: The forward and inverse functions have to be registered with Keras.
319+
To do so, use the `@keras.saving.register_keras_serializable` decorator.
320+
They must also be registered (and identical) when loading the adapter
321+
at a later point in time.
322+
323+
Examples
324+
--------
325+
326+
The example below shows how to use the
327+
`keras.saving.register_keras_serializable` decorator to
328+
register functions with Keras. Note that for this simple
329+
example, one usually would use the simpler :py:meth:`apply`
330+
method.
331+
332+
>>> import keras
333+
>>>
334+
>>> @keras.saving.register_keras_serializable("custom")
335+
>>> def forward_fn(x):
336+
>>> return x**2
337+
>>>
338+
>>> @keras.saving.register_keras_serializable("custom")
339+
>>> def inverse_fn(x):
340+
>>> return x**0.5
341+
>>>
342+
>>> adapter = bf.Adapter().apply_serializable(
343+
>>> "x",
344+
>>> forward=forward_fn,
345+
>>> inverse=inverse_fn,
346+
>>> )
273347
"""
274348
transform = FilterTransform(
275-
transform_constructor=NumpyTransform,
349+
transform_constructor=SerializableCustomTransform,
276350
predicate=predicate,
277351
include=include,
278352
exclude=exclude,

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .one_hot import OneHot
1616
from .rename import Rename
1717
from .scale import Scale
18+
from .serializable_custom_transform import SerializableCustomTransform
1819
from .shift import Shift
1920
from .sqrt import Sqrt
2021
from .standardize import Standardize

bayesflow/adapters/transforms/numpy_transform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class NumpyTransform(ElementwiseTransform):
1717
The name of the NumPy function to apply in the inverse transformation.
1818
"""
1919

20+
#: Dict of `np.ufunc` that support automatic selection of their inverse.
2021
INVERSE_METHODS = {
2122
np.arctan: np.tan,
2223
np.exp: np.log,
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from collections.abc import Callable
2+
import numpy as np
3+
from keras.saving import (
4+
deserialize_keras_object as deserialize,
5+
register_keras_serializable as serializable,
6+
serialize_keras_object as serialize,
7+
get_registered_name,
8+
get_registered_object,
9+
)
10+
from .elementwise_transform import ElementwiseTransform
11+
from ...utils import filter_kwargs
12+
import inspect
13+
14+
15+
@serializable(package="bayesflow.adapters")
16+
class SerializableCustomTransform(ElementwiseTransform):
17+
"""
18+
Transforms a parameter using a pair of registered serializable forward and inverse functions.
19+
20+
Parameters
21+
----------
22+
forward : function, no lambda
23+
Registered serializable function to transform the data in the forward pass.
24+
For the adapter to be serializable, this function has to be serializable
25+
as well (see Notes). Therefore, only proper functions and no lambda
26+
functions can be used here.
27+
inverse : function, no lambda
28+
Function to transform the data in the inverse pass.
29+
For the adapter to be serializable, this function has to be serializable
30+
as well (see Notes). Therefore, only proper functions and no lambda
31+
functions can be used here.
32+
33+
Raises
34+
------
35+
ValueError
36+
When the provided functions are not registered serializable functions.
37+
38+
Notes
39+
-----
40+
Important: The forward and inverse functions have to be registered with Keras.
41+
To do so, use the `@keras.saving.register_keras_serializable` decorator.
42+
They must also be registered (and identical) when loading the adapter
43+
at a later point in time.
44+
45+
"""
46+
47+
def __init__(
48+
self,
49+
*,
50+
forward: Callable[[np.ndarray, ...], np.ndarray],
51+
inverse: Callable[[np.ndarray, ...], np.ndarray],
52+
):
53+
super().__init__()
54+
55+
self._check_serializable(forward, label="forward")
56+
self._check_serializable(inverse, label="inverse")
57+
self._forward = forward
58+
self._inverse = inverse
59+
60+
@classmethod
61+
def _check_serializable(cls, function, label=""):
62+
GENERAL_EXAMPLE_CODE = (
63+
"The example code below shows the structure of a correctly decorated function:\n\n"
64+
"```\n"
65+
"import keras\n\n"
66+
"@keras.saving.register_keras_serializable('custom')\n"
67+
f"def my_{label}(...):\n"
68+
" [your code goes here...]\n"
69+
"```\n"
70+
)
71+
if function is None:
72+
raise TypeError(
73+
f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}"
74+
)
75+
registered_name = get_registered_name(function)
76+
# check if function is a lambda function
77+
if registered_name == "<lambda>":
78+
raise ValueError(
79+
f"The provided function for '{label}' is a lambda function, "
80+
"which cannot be serialized. "
81+
"Please provide a registered serializable function by using the "
82+
"@keras.saving.register_keras_serializable decorator."
83+
f"\n{GENERAL_EXAMPLE_CODE}"
84+
)
85+
if inspect.ismethod(function):
86+
raise ValueError(
87+
f"The provided value for '{label}' is a method, not a function. "
88+
"Methods cannot be serialized separately from their classes. "
89+
"Please provide a registered serializable function instead by "
90+
"moving the functionality to a function (i.e., outside of the class) and "
91+
"using the @keras.saving.register_keras_serializable decorator."
92+
f"\n{GENERAL_EXAMPLE_CODE}"
93+
)
94+
registered_object_for_name = get_registered_object(registered_name)
95+
if registered_object_for_name is None:
96+
try:
97+
source_max_lines = 5
98+
function_source_code = inspect.getsource(function).split("\n")
99+
if len(function_source_code) > source_max_lines:
100+
function_source_code = function_source_code[:source_max_lines] + [" [...]"]
101+
102+
example_code = "For your provided function, this would look like this:\n\n"
103+
example_code += "\n".join(
104+
["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"]
105+
+ function_source_code
106+
+ ["```"]
107+
)
108+
except OSError:
109+
example_code = GENERAL_EXAMPLE_CODE
110+
raise ValueError(
111+
f"The provided function for '{label}' is not registered with Keras.\n"
112+
"Please register the function using the "
113+
"@keras.saving.register_keras_serializable decorator.\n"
114+
f"{example_code}"
115+
)
116+
if registered_object_for_name is not function:
117+
raise ValueError(
118+
f"The provided function for '{label}' does not match the function "
119+
f"registered under its name '{registered_name}'. "
120+
f"(registered function: {registered_object_for_name}, provided function: {function}). "
121+
)
122+
123+
@classmethod
124+
def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform":
125+
if get_registered_object(config["forward"]["config"], custom_objects) is None:
126+
provided_function_msg = ""
127+
if config["_forward_source_code"]:
128+
provided_function_msg = (
129+
f"\nThe originally provided function was:\n\n```\n{config['_forward_source_code']}\n```"
130+
)
131+
raise TypeError(
132+
"\n\nPLEASE READ HERE:\n"
133+
"-----------------\n"
134+
"The forward function that was provided as `forward` "
135+
"is not registered with Keras, making deserialization impossible. "
136+
f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original "
137+
"function before loading your model."
138+
f"{provided_function_msg}"
139+
)
140+
if get_registered_object(config["inverse"]["config"], custom_objects) is None:
141+
provided_function_msg = ""
142+
if config["_inverse_source_code"]:
143+
provided_function_msg = (
144+
f"\nThe originally provided function was:\n\n```\n{config['_inverse_source_code']}\n```"
145+
)
146+
raise TypeError(
147+
"\n\nPLEASE READ HERE:\n"
148+
"-----------------\n"
149+
"The inverse function that was provided as `inverse` "
150+
"is not registered with Keras, making deserialization impossible. "
151+
f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original "
152+
"function before loading your model."
153+
f"{provided_function_msg}"
154+
)
155+
forward = deserialize(config["forward"], custom_objects)
156+
inverse = deserialize(config["inverse"], custom_objects)
157+
return cls(
158+
forward=forward,
159+
inverse=inverse,
160+
)
161+
162+
def get_config(self) -> dict:
163+
forward_source_code = inverse_source_code = None
164+
try:
165+
forward_source_code = inspect.getsource(self._forward)
166+
inverse_source_code = inspect.getsource(self._inverse)
167+
except OSError:
168+
pass
169+
return {
170+
"forward": serialize(self._forward),
171+
"inverse": serialize(self._inverse),
172+
"_forward_source_code": forward_source_code,
173+
"_inverse_source_code": inverse_source_code,
174+
}
175+
176+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
177+
# filter kwargs so that other transform args like batch_size, strict, ... are not passed through
178+
kwargs = filter_kwargs(kwargs, self._forward)
179+
return self._forward(data, **kwargs)
180+
181+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
182+
kwargs = filter_kwargs(kwargs, self._inverse)
183+
return self._inverse(data, **kwargs)

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def calibration_ecdf(
1919
figsize: Sequence[float] = None,
2020
label_fontsize: int = 16,
2121
legend_fontsize: int = 14,
22+
legend_location: str = "upper right",
2223
title_fontsize: int = 18,
2324
tick_fontsize: int = 12,
2425
rank_ecdf_color: str = "#132a70",
@@ -184,7 +185,7 @@ def calibration_ecdf(
184185

185186
for ax, title in zip(plot_data["axes"].flat, titles):
186187
ax.fill_between(z, L, U, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
187-
ax.legend(fontsize=legend_fontsize)
188+
ax.legend(fontsize=legend_fontsize, loc=legend_location)
188189
ax.set_title(title, fontsize=title_fontsize)
189190

190191
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def calibration_ecdf_from_quantiles(
1919
figsize: Sequence[float] = None,
2020
label_fontsize: int = 16,
2121
legend_fontsize: int = 14,
22+
legend_location: str = "upper right",
2223
title_fontsize: int = 18,
2324
tick_fontsize: int = 12,
2425
rank_ecdf_color: str = "#132a70",
@@ -173,7 +174,7 @@ def calibration_ecdf_from_quantiles(
173174
alpha=0.2,
174175
label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands" + "\n(pointwise)",
175176
)
176-
ax.legend(fontsize=legend_fontsize)
177+
ax.legend(fontsize=legend_fontsize, loc=legend_location)
177178
ax.set_title(title, fontsize=title_fontsize)
178179

179180
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

0 commit comments

Comments
 (0)