Skip to content

Commit b82716b

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into fix-additional-metrics
2 parents 07f7546 + 0c99bd9 commit b82716b

File tree

72 files changed

+4047
-2843
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+4047
-2843
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ Many examples from [Bayesian Cognitive Modeling: A Practical Course](https://bay
130130
1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb)
131131
2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb)
132132
3. [Two moons starter example](examples/Two_Moons_Starter.ipynb)
133-
4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation_and_Expert_Stats.ipynb)
133+
4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation.ipynb)
134134
5. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
135135
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
136136
7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)

bayesflow/adapters/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Seq
820820

821821
return self
822822

823-
def squeeze(self, keys: str | Sequence[str], *, axis: int | tuple):
823+
def squeeze(self, keys: str | Sequence[str], *, axis: int | Sequence[int]):
824824
"""Append a :py:class:`~transforms.Squeeze` transform to the adapter.
825825
826826
Parameters

bayesflow/adapters/transforms/nan_to_num.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@ class NanToNum(Transform):
1313
1414
Parameters
1515
----------
16-
default_value : float
17-
Value to substitute wherever data is NaN.
18-
return_mask : bool, default=False
19-
If True, a mask array will be returned under a new key.
20-
mask_prefix : str, default='mask_'
21-
Prefix for the mask key in the output dictionary.
16+
key : str
17+
The variable key to look for in the simulation data dict.
18+
default_value : float, optional
19+
Value to substitute wherever data is NaN. Default is 0.0.
20+
return_mask : bool, optional
21+
If True, a mask array will be returned under a new key. Default is False.
22+
mask_prefix : str, optional
23+
Prefix for the mask key in the output dictionary. Default is 'mask_'.
2224
"""
2325

2426
def __init__(self, key: str, default_value: float = 0.0, return_mask: bool = False, mask_prefix: str = "mask"):
@@ -81,10 +83,10 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
8183
values = data[self.key]
8284

8385
if not self.return_mask:
84-
values[values == self.default_value] = np.nan # we assume default_value is not in data
86+
# assumes default_value is not in nan
87+
values[values == self.default_value] = np.nan
8588
else:
8689
mask_array = data[self.mask_key].astype(bool)
87-
# Put NaNs where mask is 0
8890
values[~mask_array] = np.nan
8991

9092
data[self.key] = values

bayesflow/adapters/transforms/nnpe.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ class NNPE(ElementwiseTransform):
6565
def __init__(
6666
self,
6767
*,
68-
spike_scale: float | np.ndarray | None = None,
69-
slab_scale: float | np.ndarray | None = None,
68+
spike_scale: np.typing.ArrayLike | None = None,
69+
slab_scale: np.typing.ArrayLike | None = None,
7070
per_dimension: bool = True,
7171
seed: int | None = None,
7272
):
@@ -80,14 +80,14 @@ def __init__(
8080
def _resolve_scale(
8181
self,
8282
name: str,
83-
passed: float | np.ndarray | None,
83+
passed: np.typing.ArrayLike | None,
8484
default: float,
8585
data: np.ndarray,
8686
) -> np.ndarray | float:
8787
"""
8888
Determine spike/slab scale:
89-
- If passed is None: Automatic determination via default * std(data) (per‐dimension or global).
90-
- Else: validate & cast passed to the correct shape/type.
89+
- If `passed` is None: Automatic determination via default * std(data) (per‐dimension or global).
90+
- Else: Validate & cast `passed` to the correct shape/type.
9191
9292
Parameters
9393
----------
@@ -103,8 +103,8 @@ def _resolve_scale(
103103
104104
Returns
105105
-------
106-
float or np.ndarray
107-
The resolved scale, either as a scalar (if per_dimension=False) or an 1D array of length data.shape[-1]
106+
np.ndarray
107+
The resolved scale, either as a 0D array (if per_dimension=False) or an 1D array of length data.shape[-1]
108108
(if per_dimension=True).
109109
"""
110110

@@ -119,22 +119,22 @@ def _resolve_scale(
119119

120120
# If no scale is passed, determine scale automatically given the dimensionwise or global std
121121
if passed is None:
122-
return default * std
122+
return np.array(default * std)
123123
# If a scale is passed, check if the passed shape matches the expected shape
124124
else:
125-
if self.per_dimension:
125+
try:
126126
arr = np.asarray(passed, dtype=float)
127-
if arr.shape != expected_shape or arr.ndim != 1:
127+
except Exception as e:
128+
raise TypeError(f"{name}: expected values convertible to float, got {type(passed).__name__}") from e
129+
130+
if self.per_dimension:
131+
if arr.ndim != 1 or arr.shape != expected_shape:
128132
raise ValueError(f"{name}: expected array of shape {expected_shape}, got {arr.shape}")
129133
return arr
130134
else:
131-
try:
132-
scalar = float(passed)
133-
except TypeError:
134-
raise TypeError(f"{name}: expected a scalar convertible to float, got type {type(passed).__name__}")
135-
except ValueError:
136-
raise ValueError(f"{name}: expected a scalar convertible to float, got value {passed!r}")
137-
return scalar
135+
if arr.ndim != 0:
136+
raise ValueError(f"{name}: expected scalar, got array of shape {arr.shape}")
137+
return arr
138138

139139
def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
140140
"""
@@ -173,7 +173,7 @@ def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.nd
173173
return data + noise
174174

175175
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
176-
"""Non-invertible transform."""
176+
# Non-invertible transform
177177
return data
178178

179179
def get_config(self) -> dict:

bayesflow/adapters/transforms/random_subsample.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ class RandomSubsample(ElementwiseTransform):
88
"""
99
A transform that takes a random subsample of the data within an axis.
1010
11-
Example: adapter.random_subsample("x", sample_size = 3, axis = -1)
11+
Examples
12+
--------
1213
14+
>>> adapter = bf.Adapter().random_subsample("x", sample_size=3, axis=-1)
1315
"""
1416

1517
def __init__(
@@ -20,23 +22,22 @@ def __init__(
2022
super().__init__()
2123
if isinstance(sample_size, float):
2224
if sample_size <= 0 or sample_size >= 1:
23-
ValueError("Sample size as a percentage must be a float between 0 and 1 exclusive. ")
25+
raise ValueError("Sample size as a percentage must be a float between 0 and 1 exclusive. ")
2426
self.sample_size = sample_size
2527
self.axis = axis
2628

2729
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
28-
axis = self.axis
29-
max_sample_size = data.shape[axis]
30+
max_sample_size = data.shape[self.axis]
3031

3132
if isinstance(self.sample_size, int):
3233
sample_size = self.sample_size
3334
else:
3435
sample_size = np.round(self.sample_size * max_sample_size)
3536

3637
# random sample without replacement
37-
sample_indices = np.random.permutation(max_sample_size)[0 : sample_size - 1]
38+
sample_indices = np.random.permutation(max_sample_size)[:sample_size]
3839

39-
return np.take(data, sample_indices, axis)
40+
return np.take(data, sample_indices, self.axis)
4041

4142
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
4243
# non invertible transform

bayesflow/adapters/transforms/squeeze.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
from collections.abc import Sequence
34
from bayesflow.utils.serialization import serializable, serialize
45

56
from .elementwise_transform import ElementwiseTransform
@@ -29,8 +30,10 @@ class Squeeze(ElementwiseTransform):
2930
It is recommended to precede this transform with a :class:`~bayesflow.adapters.transforms.ToArray` transform.
3031
"""
3132

32-
def __init__(self, *, axis: int | tuple):
33+
def __init__(self, *, axis: int | Sequence[int]):
3334
super().__init__()
35+
if isinstance(axis, Sequence):
36+
axis = tuple(axis)
3437
self.axis = axis
3538

3639
def get_config(self) -> dict:

bayesflow/adapters/transforms/standardize.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Sequence
2+
import warnings
23

34
import numpy as np
45

@@ -69,6 +70,14 @@ def __init__(
6970
):
7071
super().__init__()
7172

73+
if mean is None or std is None:
74+
warnings.warn(
75+
"Dynamic standardization is deprecated and will be removed in later versions."
76+
"Instead, use the standardize argument of the approximator / workflow instance or provide "
77+
"fixed mean and std arguments. You may incur some redundant computations if you keep this transform.",
78+
FutureWarning,
79+
)
80+
7281
self.mean = mean
7382
self.std = std
7483

bayesflow/adapters/transforms/transform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __repr__(self):
2222

2323
@classmethod
2424
def from_config(cls, config: dict, custom_objects=None):
25+
# noinspection PyArgumentList
2526
return cls(**deserialize(config, custom_objects=custom_objects))
2627

2728
def get_config(self) -> dict:

bayesflow/approximators/approximator.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,16 @@
1111

1212

1313
class Approximator(BackendApproximator):
14-
def build(self, data_shapes: any) -> None:
15-
mock_data = keras.tree.map_structure(keras.ops.zeros, data_shapes)
16-
self.build_from_data(mock_data)
14+
def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
15+
raise NotImplementedError
1716

1817
@classmethod
1918
def build_adapter(cls, **kwargs) -> Adapter:
2019
# implemented by each respective architecture
2120
raise NotImplementedError
2221

23-
def build_from_data(self, data: dict[str, any]) -> None:
24-
self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
25-
self.built = True
22+
def build_from_data(self, adapted_data: dict[str, any]) -> None:
23+
raise NotImplementedError
2624

2725
@classmethod
2826
def build_dataset(
@@ -61,6 +59,9 @@ def build_dataset(
6159
max_queue_size=max_queue_size,
6260
)
6361

62+
def call(self, *args, **kwargs):
63+
return self.compute_metrics(*args, **kwargs)
64+
6465
def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
6566
"""
6667
Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
@@ -132,6 +133,7 @@ def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = N
132133
logging.info("Building on a test batch.")
133134
mock_data = dataset[0]
134135
mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
135-
self.build_from_data(mock_data)
136+
mock_data_shapes = keras.tree.map_structure(keras.ops.shape, mock_data)
137+
self.build(mock_data_shapes)
136138

137139
return super().fit(dataset=dataset, **kwargs)

bayesflow/approximators/backend_approximators/backend_approximator.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
match keras.backend.backend():
77
case "jax":
88
from .jax_approximator import JAXApproximator as BaseBackendApproximator
9-
case "numpy":
10-
from .numpy_approximator import NumpyApproximator as BaseBackendApproximator
119
case "tensorflow":
1210
from .tensorflow_approximator import TensorFlowApproximator as BaseBackendApproximator
1311
case "torch":

0 commit comments

Comments
 (0)