Skip to content

Commit 0bf125b

Browse files
committed
Merge branch 'dev' into allow-networks
2 parents 755f043 + 8482926 commit 0bf125b

File tree

14 files changed

+358
-15
lines changed

14 files changed

+358
-15
lines changed

bayesflow/adapters/adapter.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,18 @@ def shift(self, keys: str | Sequence[str], by: float | np.ndarray):
667667
self.transforms.append(MapTransform({key: Shift(shift=by) for key in keys}))
668668
return self
669669

670+
def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Sequence[int] = None, axis: int = -1):
    """Append a :py:class:`~transforms.Split` transform to the adapter.

    This is the effective inverse of :py:meth:`concatenate`: the array
    stored under ``key`` is split into one entry per name in ``into``.

    Parameters
    ----------
    key : str
        The key to split in the forward transform.
    into : Sequence[str]
        The names of the resulting entries. If a single string is given,
        the entry is simply renamed and ``indices_or_sections`` / ``axis``
        are ignored.
    indices_or_sections : int | Sequence[int], optional, default: None
        The number of sections or the indices to split on. If not given,
        splits evenly into ``len(into)`` parts.
    axis : int, optional, default: -1
        The axis to split on.

    Returns
    -------
    Adapter
        The adapter itself, to allow method chaining.
    """
    from .transforms import Split

    if isinstance(into, str):
        # degenerate case: a single target name is just a rename
        # NOTE(review): relies on Rename being in scope at module level,
        # as the original code does — confirm the top-of-file imports
        transform = Rename(key, into)
    else:
        transform = Split(key, into, indices_or_sections, axis)

    self.transforms.append(transform)
    return self
670682
def sqrt(self, keys: str | Sequence[str]):
671683
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.
672684

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .scale import Scale
1818
from .serializable_custom_transform import SerializableCustomTransform
1919
from .shift import Shift
20+
from .split import Split
2021
from .sqrt import Sqrt
2122
from .standardize import Standardize
2223
from .to_array import ToArray

bayesflow/adapters/transforms/concatenate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class Concatenate(Transform):
2323
Examples
2424
--------
2525
Suppose you have a simulator that generates variables "beta" and "sigma" from priors and then observation
26-
variables "x" and "y". We can then use concatonate in the following way
26+
variables "x" and "y". We can then use concatenate in the following way
2727
2828
>>> adapter = (
2929
bf.Adapter()
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from collections.abc import Sequence
2+
import numpy as np
3+
4+
from keras.saving import (
5+
deserialize_keras_object as deserialize,
6+
register_keras_serializable as serializable,
7+
serialize_keras_object as serialize,
8+
)
9+
10+
from .transform import Transform
11+
12+
13+
@serializable(package="bayesflow.adapters")
class Split(Transform):
    """This is the effective inverse of the :py:class:`~Concatenate` Transform.

    Parameters
    ----------
    key : str
        The key to split in the forward transform.
    into : Sequence[str]
        The names of each split after the forward transform.
    indices_or_sections : int | Sequence[int], optional, default: None
        The number of sections or indices to split on. If not given, will split evenly into len(into) parts.
    axis : int, optional, default: -1
        The axis to split on.
    """

    def __init__(self, key: str, into: Sequence[str], indices_or_sections: int | Sequence[int] = None, axis: int = -1):
        self.axis = axis
        self.key = key
        self.into = into

        # default: split evenly into one section per output name
        if indices_or_sections is None:
            indices_or_sections = len(into)

        self.indices_or_sections = indices_or_sections

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "Split":
        """Reconstruct the transform from a serialized config."""
        return cls(
            key=deserialize(config["key"], custom_objects),
            into=deserialize(config["into"], custom_objects),
            indices_or_sections=deserialize(config["indices_or_sections"], custom_objects),
            axis=deserialize(config["axis"], custom_objects),
        )

    def get_config(self) -> dict:
        """Serialize the transform's constructor arguments."""
        return {
            "key": serialize(self.key),
            "into": serialize(self.into),
            "indices_or_sections": serialize(self.indices_or_sections),
            "axis": serialize(self.axis),
        }

    def forward(self, data: dict[str, np.ndarray], strict: bool = True, **kwargs) -> dict[str, np.ndarray]:
        """Split ``data[key]`` into one entry per name in ``into``.

        Raises
        ------
        KeyError
            If ``strict`` and ``key`` is missing from ``data``.
        ValueError
            If the split does not produce exactly ``len(into)`` parts.
        """
        # avoid side effects
        data = data.copy()

        if strict and self.key not in data:
            raise KeyError(self.key)
        elif self.key not in data:
            # we cannot produce a result, but also don't have to
            return data

        splits = np.split(data.pop(self.key), self.indices_or_sections, axis=self.axis)

        if len(splits) != len(self.into):
            raise ValueError(f"Requested {len(self.into)} splits, but produced {len(splits)}.")

        for key, split in zip(self.into, splits):
            data[key] = split

        return data

    def inverse(self, data: dict[str, np.ndarray], strict: bool = False, **kwargs) -> dict[str, np.ndarray]:
        """Concatenate the entries named in ``into`` back into ``data[key]``.

        Raises
        ------
        KeyError
            If ``strict`` and any of the ``into`` keys are missing from ``data``.
        """
        # avoid side effects
        data = data.copy()

        required_keys = set(self.into)
        available_keys = set(data.keys())
        common_keys = available_keys & required_keys
        missing_keys = required_keys - available_keys

        if strict and missing_keys:
            # invalid call
            raise KeyError(f"Missing keys: {missing_keys!r}")
        elif missing_keys:
            # we cannot produce a result, but should still remove the keys
            for key in common_keys:
                data.pop(key)

            return data

        # remove each part
        splits = [data.pop(key) for key in self.into]

        # concatenate them all
        result = np.concatenate(splits, axis=self.axis)

        # store the result
        data[self.key] = result

        return data

    def extra_repr(self) -> str:
        # BUG FIX: the original mapped repr over the *characters* of the key
        # string — "[" + ", ".join(map(repr, self.key)) + "]" — producing
        # "['s', 'p', 'l', 'i', 't']" instead of the key itself.
        result = repr(self.key) + " -> " + repr(self.into)

        if self.axis != -1:
            result += f", axis={self.axis}"

        return result

bayesflow/simulators/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,16 @@
1212
from .simulator import Simulator
1313

1414
from .benchmark_simulators import (
15+
BernoulliGLM,
16+
BernoulliGLMRaw,
17+
GaussianLinear,
18+
GaussianLinearUniform,
19+
GaussianMixture,
20+
InverseKinematics,
1521
LotkaVolterra,
1622
SIR,
23+
SLCP,
24+
SLCPDistractors,
1725
TwoMoons,
1826
)
1927

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
from .bernoulli_glm import BernoulliGLM
2+
from .bernoulli_glm_raw import BernoulliGLMRaw
3+
from .gaussian_linear import GaussianLinear
4+
from .gaussian_linear_uniform import GaussianLinearUniform
5+
from .gaussian_mixture import GaussianMixture
6+
from .inverse_kinematics import InverseKinematics
17
from .lotka_volterra import LotkaVolterra
28
from .sir import SIR
9+
from .slcp import SLCP
10+
from .slcp_distractors import SLCPDistractors
311
from .two_moons import TwoMoons

bayesflow/simulators/benchmark_simulators/gaussian_linear.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
def observation_model(self, params: np.ndarray):
    """Sample Gaussian observations around the given parameter draws.

    Parameters
    ----------
    params : np.ndarray
        Parameter draws: a single draw of shape (num_params,) or a batch
        of shape (batch_size, num_params).

    Returns
    -------
    np.ndarray
        A single sample per draw if ``self.n_obs`` is None; otherwise
        ``n_obs`` samples per draw, with the observation axis placed
        directly before the parameter axis.

    Raises
    ------
    ValueError
        If ``params`` has an unsupported number of dimensions.
    """
    # Generate prior predictive samples, possibly a single one if n_obs is None
    if self.n_obs is None:
        return self.rng.normal(loc=params, scale=self.obs_scale)

    if params.ndim == 2:
        # batched sampling with n_obs: draw (n_obs, batch, dim), then move
        # the observation axis behind the batch axis
        x = self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0], params.shape[1]))
        return np.transpose(x, (1, 0, 2))

    if params.ndim == 1:
        # non-batched sampling with n_obs
        return self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0]))

    # original code fell off the end and silently returned None here
    raise ValueError(f"Expected 'params' with ndim 1 or 2, got ndim={params.ndim}.")

bayesflow/simulators/benchmark_simulators/gaussian_linear_uniform.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
def observation_model(self, params: np.ndarray):
    """Sample Gaussian observations around the given parameter draws.

    Parameters
    ----------
    params : np.ndarray
        Parameter draws: a single draw of shape (num_params,) or a batch
        of shape (batch_size, num_params).

    Returns
    -------
    np.ndarray
        A single sample per draw if ``self.n_obs`` is None; otherwise
        ``n_obs`` samples per draw, with the observation axis placed
        directly before the parameter axis.

    Raises
    ------
    ValueError
        If ``params`` has an unsupported number of dimensions.
    """
    # Generate prior predictive samples, possibly a single one if n_obs is None
    if self.n_obs is None:
        return self.rng.normal(loc=params, scale=self.obs_scale)

    if params.ndim == 2:
        # batched sampling with n_obs: draw (n_obs, batch, dim), then move
        # the observation axis behind the batch axis
        x = self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0], params.shape[1]))
        return np.transpose(x, (1, 0, 2))

    if params.ndim == 1:
        # non-batched sampling with n_obs
        return self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0]))

    # original code fell off the end and silently returned None here
    raise ValueError(f"Expected 'params' with ndim 1 or 2, got ndim={params.ndim}.")

bayesflow/simulators/sequential_simulator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
class SequentialSimulator(Simulator):
1111
"""Combines multiple simulators into one, sequentially."""
1212

13-
def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True):
13+
def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True, replace_inputs: bool = True):
1414
"""
1515
Initialize a SequentialSimulator.
1616
@@ -22,10 +22,13 @@ def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True)
2222
expand_outputs : bool, optional
2323
If True, 1D output arrays are expanded with an additional dimension at the end.
2424
Default is True.
25+
replace_inputs : bool, optional
26+
If True, **kwargs are auto-batched and replace simulator outputs.
2527
"""
2628

2729
self.simulators = simulators
2830
self.expand_outputs = expand_outputs
31+
self.replace_inputs = replace_inputs
2932

3033
@allow_batch_size
3134
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
@@ -53,6 +56,14 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
5356
for simulator in self.simulators:
5457
data |= simulator.sample(batch_shape, **(kwargs | data))
5558

59+
if self.replace_inputs:
60+
common_keys = set(data.keys()) & set(kwargs.keys())
61+
for key in common_keys:
62+
value = kwargs.pop(key)
63+
if isinstance(data[key], np.ndarray):
64+
value = np.broadcast_to(value, data[key].shape)
65+
data[key] = value
66+
5667
if self.expand_outputs:
5768
data = {
5869
key: np.expand_dims(value, axis=-1) if np.ndim(value) == 1 else value for key, value in data.items()

bayesflow/utils/dict_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def make_variable_array(
233233
else:
234234
raise TypeError(f"Only dicts and tensors are supported as arguments, but your estimates are of type {type(x)}")
235235

236-
if len(variable_names) is not x.shape[-1]:
236+
if len(variable_names) != x.shape[-1]:
237237
raise ValueError("Length of 'variable_names' should be the same as the number of variables.")
238238

239239
if variable_keys is None:

0 commit comments

Comments
 (0)