Skip to content

Commit c421f92

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into feat-adapter-nested
2 parents c145fe3 + a611f70 commit c421f92

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

47 files changed

+3725
-2701
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ Many examples from [Bayesian Cognitive Modeling: A Practical Course](https://bay
130130
1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb)
131131
2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb)
132132
3. [Two moons starter example](examples/Two_Moons_Starter.ipynb)
133-
4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation_and_Expert_Stats.ipynb)
133+
4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation.ipynb)
134134
5. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
135135
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
136136
7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)

bayesflow/adapters/transforms/random_subsample.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ class RandomSubsample(ElementwiseTransform):
88
"""
99
A transform that takes a random subsample of the data within an axis.
1010
11-
Example: adapter.random_subsample("x", sample_size = 3, axis = -1)
11+
Examples
12+
--------
1213
14+
>>> adapter = bf.Adapter().random_subsample("x", sample_size=3, axis=-1)
1315
"""
1416

1517
def __init__(
@@ -20,23 +22,22 @@ def __init__(
2022
super().__init__()
2123
if isinstance(sample_size, float):
2224
if sample_size <= 0 or sample_size >= 1:
23-
ValueError("Sample size as a percentage must be a float between 0 and 1 exclusive. ")
25+
raise ValueError("Sample size as a percentage must be a float between 0 and 1 exclusive. ")
2426
self.sample_size = sample_size
2527
self.axis = axis
2628

2729
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
28-
axis = self.axis
29-
max_sample_size = data.shape[axis]
30+
max_sample_size = data.shape[self.axis]
3031

3132
if isinstance(self.sample_size, int):
3233
sample_size = self.sample_size
3334
else:
3435
sample_size = np.round(self.sample_size * max_sample_size)
3536

3637
# random sample without replacement
37-
sample_indices = np.random.permutation(max_sample_size)[0 : sample_size - 1]
38+
sample_indices = np.random.permutation(max_sample_size)[:sample_size]
3839

39-
return np.take(data, sample_indices, axis)
40+
return np.take(data, sample_indices, self.axis)
4041

4142
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
4243
# non invertible transform

bayesflow/adapters/transforms/standardize.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Sequence
2+
import warnings
23

34
import numpy as np
45

@@ -69,6 +70,14 @@ def __init__(
6970
):
7071
super().__init__()
7172

73+
if mean is None or std is None:
74+
warnings.warn(
75+
"Dynamic standardization is deprecated and will be removed in later versions."
76+
"Instead, use the standardize argument of the approximator / workflow instance or provide "
77+
"fixed mean and std arguments. You may incur some redundant computations if you keep this transform.",
78+
FutureWarning,
79+
)
80+
7281
self.mean = mean
7382
self.std = std
7483

bayesflow/adapters/transforms/transform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __repr__(self):
2222

2323
@classmethod
2424
def from_config(cls, config: dict, custom_objects=None):
25+
# noinspection PyArgumentList
2526
return cls(**deserialize(config, custom_objects=custom_objects))
2627

2728
def get_config(self) -> dict:

bayesflow/approximators/approximator.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,16 @@
1111

1212

1313
class Approximator(BackendApproximator):
14-
def build(self, data_shapes: any) -> None:
15-
mock_data = keras.tree.map_structure(keras.ops.zeros, data_shapes)
16-
self.build_from_data(mock_data)
14+
def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
15+
raise NotImplementedError
1716

1817
@classmethod
1918
def build_adapter(cls, **kwargs) -> Adapter:
2019
# implemented by each respective architecture
2120
raise NotImplementedError
2221

23-
def build_from_data(self, data: dict[str, any]) -> None:
24-
self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
25-
self.built = True
22+
def build_from_data(self, adapted_data: dict[str, any]) -> None:
23+
raise NotImplementedError
2624

2725
@classmethod
2826
def build_dataset(
@@ -61,6 +59,9 @@ def build_dataset(
6159
max_queue_size=max_queue_size,
6260
)
6361

62+
def call(self, *args, **kwargs):
63+
return self.compute_metrics(*args, **kwargs)
64+
6465
def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
6566
"""
6667
Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
@@ -132,6 +133,7 @@ def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = N
132133
logging.info("Building on a test batch.")
133134
mock_data = dataset[0]
134135
mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
135-
self.build_from_data(mock_data)
136+
mock_data_shapes = keras.tree.map_structure(keras.ops.shape, mock_data)
137+
self.build(mock_data_shapes)
136138

137139
return super().fit(dataset=dataset, **kwargs)

0 commit comments

Comments
 (0)