Skip to content

Commit 051f998

Browse files
committed
Merge remote-tracking branch 'origin/dev' into dev
2 parents 1c809a7 + 5ac7c99 commit 051f998

File tree

62 files changed

+5112
-4414
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+5112
-4414
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
repos:
1515
- repo: https://github.com/astral-sh/ruff-pre-commit
1616
# Ruff version.
17-
rev: v0.4.9
17+
rev: v0.9.6
1818
hooks:
1919
# Run the linter.
2020
- id: ruff
@@ -28,7 +28,7 @@ repos:
2828
# - id: nbqa-ruff
2929
# args: [--ignore=E402] # E402: module level import not at top of file
3030
- repo: https://github.com/pre-commit/pre-commit-hooks
31-
rev: v4.4.0
31+
rev: v5.0.0
3232
hooks:
3333
# A bunch of other pre-defined hooks.
3434
- id: check-yaml

README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,14 @@ conda env create --file environment.yaml --name bayesflow
9393

9494
Check out some of our walk-through notebooks below. We are actively working on porting all notebooks to the new interface so more will be available soon!
9595

96-
1. [Two moons starter toy example](examples/TwoMoons_StarterNotebook.ipynb)
97-
2. [Linear regression](examples/Linear_Regression.ipynb)
98-
3. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
99-
4. [SIR model with custom summary network](examples/SIR_PosteriorEstimation.ipynb)
96+
1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb)
97+
2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb)
98+
3. [Two moons starter example](examples/Two_Moons_Starter.ipynb)
99+
4. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
100100
5. [Hyperparameter optimization](examples/Hyperparameter_Optimization.ipynb)
101-
6. Coming soon...
101+
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
102+
7. [Simple model comparison example (One-Sample T-Test)](examples/One_Sample_TTest.ipynb)
103+
8. More coming soon...
102104

103105
## Documentation \& Help
104106

bayesflow/adapters/transforms/broadcast.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,19 @@ def __init__(
7979

8080
@classmethod
8181
def from_config(cls, config: dict, custom_objects=None) -> "Broadcast":
82+
# Deserialize turns tuples to lists, undo it if necessary
83+
exclude = deserialize(config["exclude"], custom_objects)
84+
exclude = tuple(exclude) if isinstance(exclude, list) else exclude
85+
expand = deserialize(config["expand"], custom_objects)
86+
expand = tuple(expand) if isinstance(expand, list) else expand
87+
squeeze = deserialize(config["squeeze"], custom_objects)
88+
squeeze = tuple(squeeze) if isinstance(squeeze, list) else squeeze
8289
return cls(
8390
keys=deserialize(config["keys"], custom_objects),
8491
to=deserialize(config["to"], custom_objects),
85-
expand=deserialize(config["expand"], custom_objects),
86-
exclude=deserialize(config["exclude"], custom_objects),
87-
squeeze=deserialize(config["squeeze"], custom_objects),
92+
expand=expand,
93+
exclude=exclude,
94+
squeeze=squeeze,
8895
)
8996

9097
def get_config(self) -> dict:

bayesflow/adapters/transforms/concatenate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,26 +30,28 @@ class Concatenate(Transform):
3030
)
3131
"""
3232

33-
def __init__(self, keys: Sequence[str], *, into: str, axis: int = -1):
33+
def __init__(self, keys: Sequence[str], *, into: str, axis: int = -1, _indices: list | None = None):
3434
self.keys = keys
3535
self.into = into
3636
self.axis = axis
3737

38-
self.indices = None
38+
self.indices = _indices
3939

4040
@classmethod
4141
def from_config(cls, config: dict, custom_objects=None) -> "Concatenate":
4242
return cls(
4343
keys=deserialize(config["keys"], custom_objects),
4444
into=deserialize(config["into"], custom_objects),
4545
axis=deserialize(config["axis"], custom_objects),
46+
_indices=deserialize(config["indices"], custom_objects),
4647
)
4748

4849
def get_config(self) -> dict:
4950
return {
5051
"keys": serialize(self.keys),
5152
"into": serialize(self.into),
5253
"axis": serialize(self.axis),
54+
"indices": serialize(self.indices),
5355
}
5456

5557
def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]:

bayesflow/adapters/transforms/filter_transform.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import numpy as np
55
from keras.saving import (
66
deserialize_keras_object as deserialize,
7+
get_registered_name,
8+
get_registered_object,
79
register_keras_serializable as serializable,
810
serialize_keras_object as serialize,
911
)
@@ -79,21 +81,33 @@ def extra_repr(self) -> str:
7981

8082
@classmethod
8183
def from_config(cls, config: dict, custom_objects=None) -> "Transform":
82-
def transform_constructor(*args, **kwargs):
83-
raise RuntimeError(
84-
"Instantiating new elementwise transforms on a deserialized FilterTransform is not yet supported (and"
85-
"may never be). As a work-around, you can manually register the elementwise transform constructor after"
86-
"deserialization:\n"
87-
"obj = deserialize(config)\n"
88-
"obj.transform_constructor = MyElementwiseTransform"
89-
)
90-
84+
transform_constructor = get_registered_object(config["transform_constructor"])
85+
try:
86+
kwargs = deserialize(config["kwargs"])
87+
except TypeError as e:
88+
if transform_constructor.__name__ == "LambdaTransform":
89+
raise TypeError(
90+
"LambdaTransform (created by Adapter.apply) could not be deserialized.\n"
91+
"This is probably because the custom transform functions `forward` and "
92+
"`backward` from `Adapter.apply` were not passed as `custom_objects`.\n"
93+
"For example, if your adapter uses\n"
94+
"`Adapter.apply(forward=forward_transform, inverse=inverse_transform)`,\n"
95+
"you have to pass\n"
96+
'`custom_objects={"forward_transform": forward_transform, '
97+
'"inverse_transform": inverse_transform}`\n'
98+
"to the function you use to load the serialized object."
99+
) from e
100+
raise TypeError(
101+
"The transform could not be deserialized properly. "
102+
"The most likely reason is that some classes or functions "
103+
"are not known during deserialization. Please pass them as `custom_objects`."
104+
) from e
91105
instance = cls(
92106
transform_constructor=transform_constructor,
93107
predicate=deserialize(config["predicate"], custom_objects),
94108
include=deserialize(config["include"], custom_objects),
95109
exclude=deserialize(config["exclude"], custom_objects),
96-
**config["kwargs"],
110+
**kwargs,
97111
)
98112

99113
instance.transform_map = deserialize(config["transform_map"])
@@ -102,6 +116,7 @@ def transform_constructor(*args, **kwargs):
102116

103117
def get_config(self) -> dict:
104118
return {
119+
"transform_constructor": get_registered_name(self.transform_constructor),
105120
"predicate": serialize(self.predicate),
106121
"include": serialize(self.include),
107122
"exclude": serialize(self.exclude),

bayesflow/adapters/transforms/standardize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,14 @@ def __init__(
4141

4242
@classmethod
4343
def from_config(cls, config: dict, custom_objects=None) -> "Standardize":
44+
# Deserialize turns tuples to lists, undo it if necessary
45+
deserialized_axis = deserialize(config["axis"], custom_objects)
46+
if isinstance(deserialized_axis, list):
47+
deserialized_axis = tuple(deserialized_axis)
4448
return cls(
4549
mean=deserialize(config["mean"], custom_objects),
4650
std=deserialize(config["std"], custom_objects),
47-
axis=deserialize(config["axis"], custom_objects),
51+
axis=deserialized_axis,
4852
momentum=deserialize(config["momentum"], custom_objects),
4953
)
5054

bayesflow/approximators/continuous_approximator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ def sample(
140140
**kwargs,
141141
) -> dict[str, np.ndarray]:
142142
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
143+
# at inference time, inference_variables are estimated by the networks and thus ignored in conditions
144+
conditions.pop("inference_variables", None)
143145
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
144146
conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions, **kwargs)}
145147
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Mapping, Sequence
22

33
import keras
4+
import numpy as np
45
from keras.saving import (
56
deserialize_keras_object as deserialize,
67
register_keras_serializable as serializable,
@@ -198,3 +199,46 @@ def get_config(self):
198199
}
199200

200201
return base_config | config
202+
203+
def predict(
204+
self,
205+
*,
206+
conditions: dict[str, np.ndarray],
207+
logits: bool = False,
208+
**kwargs,
209+
) -> np.ndarray:
210+
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
211+
# at inference time, model_indices are predicted by the networks and thus ignored in conditions
212+
conditions.pop("model_indices", None)
213+
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
214+
215+
output = self._predict(**conditions, **kwargs)
216+
217+
if not logits:
218+
output = keras.ops.softmax(output)
219+
220+
output = keras.ops.convert_to_numpy(output)
221+
222+
return output
223+
224+
def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tensor = None, **kwargs) -> Tensor:
225+
if self.summary_network is None:
226+
if summary_variables is not None:
227+
raise ValueError("Cannot use summary variables without a summary network.")
228+
else:
229+
if summary_variables is None:
230+
raise ValueError("Summary variables are required when a summary network is present")
231+
232+
summary_outputs = self.summary_network(
233+
summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
234+
)
235+
236+
if classifier_conditions is None:
237+
classifier_conditions = summary_outputs
238+
else:
239+
classifier_conditions = keras.ops.concatenate([classifier_conditions, summary_outputs], axis=1)
240+
241+
output = self.classifier_network(classifier_conditions)
242+
output = self.logits_projector(output)
243+
244+
return output

bayesflow/datasets/offline_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
class OfflineDataset(keras.utils.PyDataset):
88
"""
9-
A dataset that is pre-simulated and stored in memory.
9+
A dataset that is pre-simulated and stored in memory. When storing and loading data from disk, it is recommended to
10+
save any pre-simulated data in raw form and create the `OfflineDataset` object only after loading in the raw data.
11+
See the `DiskDataset` class for handling large datasets that are split into multiple smaller files.
1012
"""
1113

1214
def __init__(self, data: dict[str, np.ndarray], batch_size: int, adapter: Adapter | None, **kwargs):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .calibration_error import calibration_error
22
from .posterior_contraction import posterior_contraction
33
from .root_mean_squared_error import root_mean_squared_error
4+
from .expected_calibration_error import expected_calibration_error

0 commit comments

Comments
 (0)