
Commit 869d98f

Merge remote-tracking branch 'upstream/dev' into feat-test-compatibility
2 parents: 8ed5dab + 057f3fd

File tree

24 files changed: +258 -137 lines


bayesflow/adapters/adapter.py

Lines changed: 1 addition & 1 deletion
@@ -820,7 +820,7 @@ def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Seq
 
         return self
 
-    def squeeze(self, keys: str | Sequence[str], *, axis: int | tuple):
+    def squeeze(self, keys: str | Sequence[str], *, axis: int | Sequence[int]):
         """Append a :py:class:`~transforms.Squeeze` transform to the adapter.
 
         Parameters
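
The relaxed `axis` annotation means any integer sequence is now accepted where previously only an `int` or `tuple` was. A minimal usage sketch (chained-adapter call style as used elsewhere in this diff; shapes are illustrative):

import numpy as np
from bayesflow.adapters import Adapter

# A list is now a valid axis argument, not just int | tuple.
adapter = Adapter().squeeze("x", axis=[1, 2])

data = {"x": np.zeros((32, 1, 1, 5))}
squeezed = adapter(data)  # "x" would come out with shape (32, 5)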

bayesflow/adapters/transforms/nan_to_num.py

Lines changed: 10 additions & 8 deletions
@@ -13,12 +13,14 @@ class NanToNum(Transform):
 
     Parameters
     ----------
-    default_value : float
-        Value to substitute wherever data is NaN.
-    return_mask : bool, default=False
-        If True, a mask array will be returned under a new key.
-    mask_prefix : str, default='mask_'
-        Prefix for the mask key in the output dictionary.
+    key : str
+        The variable key to look for in the simulation data dict.
+    default_value : float, optional
+        Value to substitute wherever data is NaN. Default is 0.0.
+    return_mask : bool, optional
+        If True, a mask array will be returned under a new key. Default is False.
+    mask_prefix : str, optional
+        Prefix for the mask key in the output dictionary. Default is 'mask_'.
     """
 
     def __init__(self, key: str, default_value: float = 0.0, return_mask: bool = False, mask_prefix: str = "mask"):
@@ -81,10 +83,10 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
         values = data[self.key]
 
         if not self.return_mask:
-            values[values == self.default_value] = np.nan  # we assume default_value is not in data
+            # assumes default_value is not in data
+            values[values == self.default_value] = np.nan
         else:
             mask_array = data[self.mask_key].astype(bool)
-            # Put NaNs where mask is 0
             values[~mask_array] = np.nan
 
         data[self.key] = values
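
Taken together, forward replaces NaNs and optionally emits a mask, while the inverse above re-inserts NaNs. A short sketch of the round trip; the import path follows this file's location, and the exact mask key produced from `mask_prefix` is an assumption:

import numpy as np
from bayesflow.adapters.transforms import NanToNum

transform = NanToNum(key="x", default_value=0.0, return_mask=True)
data = {"x": np.array([1.0, np.nan, 3.0])}

out = transform.forward(data)
# out["x"] -> array([1., 0., 3.]), NaNs replaced by default_value
# out["mask_x"] -> 1 where data was valid, 0 where it was NaN (key name assumed)

restored = transform.inverse(out)
# restored["x"] -> array([1., nan, 3.]), NaNs restored where the mask is 0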

bayesflow/adapters/transforms/nnpe.py

Lines changed: 18 additions & 18 deletions
@@ -65,8 +65,8 @@ class NNPE(ElementwiseTransform):
     def __init__(
         self,
         *,
-        spike_scale: float | np.ndarray | None = None,
-        slab_scale: float | np.ndarray | None = None,
+        spike_scale: np.typing.ArrayLike | None = None,
+        slab_scale: np.typing.ArrayLike | None = None,
         per_dimension: bool = True,
         seed: int | None = None,
     ):
@@ -80,14 +80,14 @@ def __init__(
     def _resolve_scale(
         self,
         name: str,
-        passed: float | np.ndarray | None,
+        passed: np.typing.ArrayLike | None,
         default: float,
         data: np.ndarray,
     ) -> np.ndarray | float:
         """
         Determine spike/slab scale:
-        - If passed is None: Automatic determination via default * std(data) (per-dimension or global).
-        - Else: validate & cast passed to the correct shape/type.
+        - If `passed` is None: Automatic determination via default * std(data) (per-dimension or global).
+        - Else: Validate & cast `passed` to the correct shape/type.
 
         Parameters
         ----------
@@ -103,8 +103,8 @@ def _resolve_scale(
 
         Returns
         -------
-        float or np.ndarray
-            The resolved scale, either as a scalar (if per_dimension=False) or a 1D array of length data.shape[-1]
+        np.ndarray
+            The resolved scale, either as a 0D array (if per_dimension=False) or a 1D array of length data.shape[-1]
             (if per_dimension=True).
         """
 
@@ -119,22 +119,22 @@ def _resolve_scale(
 
         # If no scale is passed, determine scale automatically given the dimensionwise or global std
         if passed is None:
-            return default * std
+            return np.array(default * std)
         # If a scale is passed, check if the passed shape matches the expected shape
         else:
-            if self.per_dimension:
+            try:
                 arr = np.asarray(passed, dtype=float)
-                if arr.shape != expected_shape or arr.ndim != 1:
+            except Exception as e:
+                raise TypeError(f"{name}: expected values convertible to float, got {type(passed).__name__}") from e
+
+            if self.per_dimension:
+                if arr.ndim != 1 or arr.shape != expected_shape:
                     raise ValueError(f"{name}: expected array of shape {expected_shape}, got {arr.shape}")
                 return arr
             else:
-                try:
-                    scalar = float(passed)
-                except TypeError:
-                    raise TypeError(f"{name}: expected a scalar convertible to float, got type {type(passed).__name__}")
-                except ValueError:
-                    raise ValueError(f"{name}: expected a scalar convertible to float, got value {passed!r}")
-                return scalar
+                if arr.ndim != 0:
+                    raise ValueError(f"{name}: expected scalar, got array of shape {arr.shape}")
+                return arr
 
     def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
         """
@@ -173,7 +173,7 @@ def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
         return data + noise
 
     def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
-        """Non-invertible transform."""
+        # Non-invertible transform
        return data
 
     def get_config(self) -> dict:
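
The `_resolve_scale` rewrite funnels every passed value through `np.asarray`, so list inputs now work and conversion failures raise a uniform `TypeError`. A hedged construction sketch (import path assumed from this file's location; `forward` semantics inferred from the signature above):

import numpy as np
from bayesflow.adapters.transforms import NNPE

data = np.random.normal(size=(128, 3))

# per_dimension=True: the scale must resolve to a 1D array of length data.shape[-1];
# a plain list is fine now because it goes through np.asarray first.
nnpe = NNPE(spike_scale=[0.01, 0.02, 0.01], slab_scale=None, per_dimension=True, seed=42)
noisy = nnpe.forward(data, stage="training")

# per_dimension=False: anything that does not resolve to a 0D scalar raises ValueError.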

bayesflow/adapters/transforms/squeeze.py

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,6 @@
 import numpy as np
 
+from collections.abc import Sequence
 from bayesflow.utils.serialization import serializable, serialize
 
 from .elementwise_transform import ElementwiseTransform
@@ -29,8 +30,10 @@ class Squeeze(ElementwiseTransform):
     It is recommended to precede this transform with a :class:`~bayesflow.adapters.transforms.ToArray` transform.
     """
 
-    def __init__(self, *, axis: int | tuple):
+    def __init__(self, *, axis: int | Sequence[int]):
         super().__init__()
+        if isinstance(axis, Sequence):
+            axis = tuple(axis)
         self.axis = axis
 
     def get_config(self) -> dict:
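
A quick sketch of the added normalization (import path assumed from this file's location); coercing the sequence to a tuple keeps `self.axis` in the form `np.squeeze` expects:

import numpy as np
from bayesflow.adapters.transforms import Squeeze

t = Squeeze(axis=[1, 3])  # a Sequence[int] is coerced to a tuple in __init__
assert t.axis == (1, 3)

x = np.zeros((2, 1, 4, 1))
print(np.squeeze(x, axis=t.axis).shape)  # (2, 4)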

bayesflow/approximators/backend_approximators/backend_approximator.py

Lines changed: 0 additions & 2 deletions
@@ -6,8 +6,6 @@
 match keras.backend.backend():
     case "jax":
         from .jax_approximator import JAXApproximator as BaseBackendApproximator
-    case "numpy":
-        from .numpy_approximator import NumpyApproximator as BaseBackendApproximator
     case "tensorflow":
         from .tensorflow_approximator import TensorFlowApproximator as BaseBackendApproximator
     case "torch":

bayesflow/approximators/backend_approximators/jax_approximator.py

Lines changed: 14 additions & 0 deletions
@@ -56,6 +56,20 @@ def stateless_compute_metrics(
         variables and returns both the loss and auxiliary information for
         further updates.
 
+        Things we do specifically for JAX:
+
+        1. Accept trainable variables as the first argument
+           (can be at any position as indicated by the argnum parameter
+           in autograd, but needs to be an explicit arg)
+        2. Accept, potentially modify, and return other state variables
+        3. Return just the loss tensor as the first value
+        4. Return all other values in a tuple as the second value
+
+        This ensures:
+
+        1. The function is stateless
+        2. The function can be differentiated with jax autograd
+
         Parameters
         ----------
         trainable_variables : Any
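
A toy, self-contained illustration of this contract (not BayesFlow's actual implementation): trainable variables come first, other state is threaded through, and the bare loss is returned first so `jax.value_and_grad(..., has_aux=True)` can differentiate it.

import jax
import jax.numpy as jnp

def stateless_loss(trainable_variables, non_trainable_variables, batch):
    # (1) trainable variables are an explicit first argument
    w = trainable_variables["w"]
    loss = jnp.mean((batch["x"] @ w - batch["y"]) ** 2)
    # (2) other state is accepted and returned (unchanged here);
    # (3) the bare loss tensor is the first return value;
    # (4) everything else rides in the aux tuple.
    return loss, (non_trainable_variables, {"loss": loss})

grad_fn = jax.value_and_grad(stateless_loss, has_aux=True)
params = {"w": jnp.ones((3, 1))}
batch = {"x": jnp.ones((8, 3)), "y": jnp.zeros((8, 1))}
(loss, aux), grads = grad_fn(params, {}, batch)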

bayesflow/approximators/backend_approximators/numpy_approximator.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

bayesflow/approximators/continuous_approximator.py

Lines changed: 26 additions & 9 deletions
@@ -3,6 +3,7 @@
 import numpy as np
 
 import keras
+import warnings
 
 from bayesflow.adapters import Adapter
 from bayesflow.networks import InferenceNetwork, SummaryNetwork
@@ -98,8 +99,8 @@ def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
             ]
             self.standardize_layers = {var: Standardization(trainable=False) for var in self.standardize}
 
-        # Build all standardization layers, if present
-        for var, layer in getattr(self, "standardize_layers", {}).items():
+        # Build all standardization layers
+        for var, layer in self.standardize_layers.items():
             layer.build(data_shapes[var])
 
         self.built = True
@@ -448,7 +449,7 @@ def sample(
         conditions = self._prepare_data(conditions, **kwargs)
 
         # Remove any superfluous keys, just retain actual conditions
-        conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.CONDITION_KEYS}
+        conditions = {k: v for k, v in conditions.items() if k in self.CONDITION_KEYS}
 
         # Sample and undo optional standardization
         samples = self._sample(num_samples=num_samples, **conditions, **kwargs)
@@ -485,7 +486,7 @@ def _prepare_data(
         ldj_inference = None
 
         # Standardize conditions
-        for key in ContinuousApproximator.CONDITION_KEYS:
+        for key in self.CONDITION_KEYS:
             if key in self.standardize and key in data:
                 data[key] = self.standardize_layers[key](data[key])
 
@@ -514,8 +515,12 @@ def _sample(
         summary_variables: Tensor = None,
         **kwargs,
     ) -> Tensor:
-        if (self.summary_network is None) != (summary_variables is None):
-            raise ValueError("Summary variables and summary network must be used together.")
+        if self.summary_network is None:
+            if summary_variables is not None:
+                raise ValueError("Cannot use summary variables without a summary network.")
+        else:
+            if summary_variables is None:
+                raise ValueError("Summary variables are required when a summary network is present.")
 
         if self.summary_network is not None:
             summary_outputs = self.summary_network(
@@ -539,7 +544,7 @@
             batch_shape, conditions=inference_conditions, **filter_kwargs(kwargs, self.inference_network.sample)
         )
 
-    def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
+    def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
         """
         Computes the learned summary statistics of given summary variables.
 
@@ -570,6 +575,14 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
 
         return summaries
 
+    def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
+        """
+        .. deprecated:: 2.0.4
+            `summaries` will be removed in version 2.0.5, it was renamed to `summarize` which should be used instead.
+        """
+        warnings.warn("`summaries` was renamed to `summarize` and will be removed in version 2.0.5.", FutureWarning)
+        return self.summarize(data=data, **kwargs)
+
     def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
         """
         Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the
@@ -606,8 +619,12 @@ def _log_prob(
         summary_variables: Tensor = None,
         **kwargs,
     ) -> Tensor:
-        if (self.summary_network is None) != (summary_variables is None):
-            raise ValueError("Summary variables and summary network must be used together.")
+        if self.summary_network is None:
+            if summary_variables is not None:
+                raise ValueError("Cannot use summary variables without a summary network.")
+        else:
+            if summary_variables is None:
+                raise ValueError("Summary variables are required when a summary network is present.")
 
         if self.summary_network is not None:
             summary_outputs = self.summary_network(
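
Call-site impact of the rename is minimal; a hypothetical usage (the `approximator` and `data` objects are placeholders for a trained ContinuousApproximator and its adapted inputs):

stats = approximator.summarize(data)  # new name
stats = approximator.summaries(data)  # deprecated alias, warns FutureWarning until removal in 2.0.5

# The split ValueError also makes the failure modes explicit: passing
# summary_variables to an approximator built without a summary network now
# raises "Cannot use summary variables without a summary network."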

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 20 additions & 8 deletions
@@ -2,6 +2,7 @@
 
 import keras
 import numpy as np
+import warnings
 
 from bayesflow.adapters import Adapter
 from bayesflow.datasets import OnlineDataset
@@ -92,11 +93,11 @@ def build(self, data_shapes: dict[str, tuple[int] | dict[str, dict]]) -> None:
 
         # Set up standardization layers if requested
         if self.standardize == "all":
-            self.standardize = [var for var in ModelComparisonApproximator.CONDITION_KEYS if var in data_shapes]
+            self.standardize = [var for var in self.CONDITION_KEYS if var in data_shapes]
         self.standardize_layers = {var: Standardization(trainable=False) for var in self.standardize}
 
         # Build all standardization layers
-        for var, layer in getattr(self, "standardize_layers", {}).items():
+        for var, layer in self.standardize_layers.items():
             layer.build(data_shapes[var])
 
         self.built = True
@@ -242,7 +243,7 @@ def compute_metrics(
     def fit(
         self,
         *,
-        adapter: Adapter | str = "auto",
+        adapter: Adapter = "auto",
         dataset: keras.utils.PyDataset = None,
         simulator: ModelComparisonSimulator = None,
         simulators: Sequence[Simulator] = None,
@@ -256,7 +257,7 @@
 
         Parameters
         ----------
-        adapter : Adapter or str, optional
+        adapter : Adapter or 'auto', optional
             The data adapter that will make the simulated / real outputs neural-network friendly.
         dataset : keras.utils.PyDataset, optional
             A dataset containing simulations for training. If provided, `simulator` must be None.
@@ -392,19 +393,22 @@ def predict(
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
 
         # Ensure only keys relevant for sampling are present in the conditions dictionary
-        conditions = {k: v for k, v in conditions.items() if k in ModelComparisonApproximator.CONDITION_KEYS}
+        conditions = {k: v for k, v in conditions.items() if k in self.CONDITION_KEYS}
         conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
 
         # Optionally standardize conditions
-        for key in ModelComparisonApproximator.CONDITION_KEYS:
+        for key in self.CONDITION_KEYS:
             if key in conditions and key in self.standardize:
                 conditions[key] = self.standardize_layers[key](conditions[key])
 
         output = self._predict(**conditions, **kwargs)
 
-        return keras.ops.convert_to_numpy(keras.ops.softmax(output) if probs else output)
+        if probs:
+            output = keras.ops.softmax(output)
 
-    def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
+        return keras.ops.convert_to_numpy(output)
+
+    def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
         """
         Computes the learned summary statistics of given summary variables.
 
@@ -435,6 +439,14 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
 
         return summaries
 
+    def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
+        """
+        .. deprecated:: 2.0.4
+            `summaries` will be removed in version 2.0.5, it was renamed to `summarize` which should be used instead.
+        """
+        warnings.warn("`summaries` was renamed to `summarize` and will be removed in version 2.0.5.", FutureWarning)
+        return self.summarize(data=data, **kwargs)
+
     def _compute_logits(self, classifier_conditions: Tensor) -> Tensor:
         """Helper to compute projected logits from the classifier network."""
         logits = self.classifier_network(classifier_conditions)
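
The refactored `predict` tail is behavior-preserving; a hypothetical call (`approximator` and `conditions` are placeholders for a trained ModelComparisonApproximator and its input dict):

import numpy as np

logits = approximator.predict(conditions=conditions, probs=False)  # raw logits
probs = approximator.predict(conditions=conditions, probs=True)    # softmax over models

# Softmax outputs sum to one across the model axis.
assert np.allclose(probs.sum(axis=-1), 1.0)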
