Skip to content

Commit c36bc3b

Browse files
authored
fix gaussian linear benchmark, expose benchmark simulators. (#414)
1 parent 12921dc commit c36bc3b

File tree

6 files changed

+163
-11
lines changed

6 files changed

+163
-11
lines changed

bayesflow/simulators/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,16 @@
1212
from .simulator import Simulator
1313

1414
from .benchmark_simulators import (
15+
BernoulliGLM,
16+
BernoulliGLMRaw,
17+
GaussianLinear,
18+
GaussianLinearUniform,
19+
GaussianMixture,
20+
InverseKinematics,
1521
LotkaVolterra,
1622
SIR,
23+
SLCP,
24+
SLCPDistractors,
1725
TwoMoons,
1826
)
1927

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
from .bernoulli_glm import BernoulliGLM
2+
from .bernoulli_glm_raw import BernoulliGLMRaw
3+
from .gaussian_linear import GaussianLinear
4+
from .gaussian_linear_uniform import GaussianLinearUniform
5+
from .gaussian_mixture import GaussianMixture
6+
from .inverse_kinematics import InverseKinematics
17
from .lotka_volterra import LotkaVolterra
28
from .sir import SIR
9+
from .slcp import SLCP
10+
from .slcp_distractors import SLCPDistractors
311
from .two_moons import TwoMoons

bayesflow/simulators/benchmark_simulators/gaussian_linear.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,10 @@ def observation_model(self, params: np.ndarray):
7575
# Generate prior predictive samples, possibly a single if n_obs is None
7676
if self.n_obs is None:
7777
return self.rng.normal(loc=params, scale=self.obs_scale)
78-
x = self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0], params.shape[1]))
79-
return np.transpose(x, (1, 0, 2))
78+
if params.ndim == 2:
79+
# batched sampling with n_obs
80+
x = self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0], params.shape[1]))
81+
return np.transpose(x, (1, 0, 2))
82+
elif params.ndim == 1:
83+
# non-batched sampling with n_obs
84+
return self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0]))

bayesflow/simulators/benchmark_simulators/gaussian_linear_uniform.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,10 @@ def observation_model(self, params: np.ndarray):
7979
# Generate prior predictive samples, possibly a single if n_obs is None
8080
if self.n_obs is None:
8181
return self.rng.normal(loc=params, scale=self.obs_scale)
82-
x = self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0], params.shape[1]))
83-
return np.transpose(x, (1, 0, 2))
82+
if params.ndim == 2:
83+
# batched sampling with n_obs
84+
x = self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0], params.shape[1]))
85+
return np.transpose(x, (1, 0, 2))
86+
elif params.ndim == 1:
87+
# non-batched sampling with n_obs
88+
return self.rng.normal(loc=params, scale=self.obs_scale, size=(self.n_obs, params.shape[0]))

tests/test_simulators/conftest.py

Lines changed: 123 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,97 @@ def use_squeezed(request):
2222
return request.param
2323

2424

25+
@pytest.fixture()
def bernoulli_glm():
    """Provide a BernoulliGLM benchmark simulator with default settings."""
    # Import inside the fixture so the module is only loaded when used.
    from bayesflow.simulators import BernoulliGLM

    return BernoulliGLM()
30+
31+
32+
@pytest.fixture()
def bernoulli_glm_raw():
    """Provide a BernoulliGLMRaw benchmark simulator with default settings."""
    from bayesflow.simulators import BernoulliGLMRaw

    return BernoulliGLMRaw()
37+
38+
39+
@pytest.fixture()
def gaussian_linear():
    """Provide a GaussianLinear benchmark simulator with default settings (no fixed n_obs)."""
    from bayesflow.simulators import GaussianLinear

    return GaussianLinear()
44+
45+
46+
@pytest.fixture()
def gaussian_linear_n_obs():
    """Provide a GaussianLinear benchmark simulator with a fixed number of observations (n_obs=5)."""
    from bayesflow.simulators import GaussianLinear

    return GaussianLinear(n_obs=5)
51+
52+
53+
@pytest.fixture()
def gaussian_linear_uniform():
    """Provide a GaussianLinearUniform benchmark simulator with default settings (no fixed n_obs)."""
    from bayesflow.simulators import GaussianLinearUniform

    return GaussianLinearUniform()
58+
59+
60+
@pytest.fixture()
def gaussian_linear_uniform_n_obs():
    """Provide a GaussianLinearUniform benchmark simulator with a fixed number of observations (n_obs=5)."""
    from bayesflow.simulators import GaussianLinearUniform

    return GaussianLinearUniform(n_obs=5)
65+
66+
67+
@pytest.fixture(
    params=["gaussian_linear", "gaussian_linear_n_obs", "gaussian_linear_uniform", "gaussian_linear_uniform_n_obs"]
)
def gaussian_linear_simulator(request):
    """Parametrize over all Gaussian linear simulator variants, with and without a fixed n_obs."""
    # Resolve the named fixture so each variant is constructed lazily per test.
    return request.getfixturevalue(request.param)
72+
73+
74+
@pytest.fixture()
def gaussian_mixture():
    """Provide a GaussianMixture benchmark simulator with default settings."""
    from bayesflow.simulators import GaussianMixture

    return GaussianMixture()
79+
80+
81+
@pytest.fixture()
def inverse_kinematics():
    """Provide an InverseKinematics benchmark simulator with default settings."""
    from bayesflow.simulators import InverseKinematics

    return InverseKinematics()
86+
87+
88+
@pytest.fixture()
def lotka_volterra():
    """Provide a LotkaVolterra benchmark simulator with default settings."""
    from bayesflow.simulators import LotkaVolterra

    return LotkaVolterra()
93+
94+
95+
@pytest.fixture()
def sir():
    """Provide an SIR benchmark simulator with default settings."""
    from bayesflow.simulators import SIR

    return SIR()
100+
101+
102+
@pytest.fixture()
def slcp():
    """Provide an SLCP benchmark simulator with default settings."""
    from bayesflow.simulators import SLCP

    return SLCP()
107+
108+
109+
@pytest.fixture()
def slcp_distractors():
    """Provide an SLCPDistractors benchmark simulator with default settings."""
    from bayesflow.simulators import SLCPDistractors

    return SLCPDistractors()
114+
115+
25116
@pytest.fixture()
26117
def composite_two_moons():
27118
from bayesflow.simulators import make_simulator
@@ -40,13 +131,40 @@ def observables(parameters):
40131
return make_simulator([parameters, observables])
41132

42133

43-
@pytest.fixture(params=["composite_two_moons", "two_moons"])
44-
def simulator(request):
45-
return request.getfixturevalue(request.param)
46-
47-
48134
@pytest.fixture()
def two_moons():
    """Provide a TwoMoons benchmark simulator with default settings."""
    from bayesflow.simulators import TwoMoons

    return TwoMoons()
139+
140+
141+
@pytest.fixture(
    params=[
        "composite_two_moons",
        "two_moons",
    ]
)
def two_moons_simulator(request):
    """Parametrize over the two-moons simulators (composite via make_simulator, and the built-in class)."""
    return request.getfixturevalue(request.param)
149+
150+
151+
@pytest.fixture(
    params=[
        "bernoulli_glm",
        "bernoulli_glm_raw",
        "gaussian_linear",
        "gaussian_linear_n_obs",
        "gaussian_linear_uniform",
        "gaussian_linear_uniform_n_obs",
        "gaussian_mixture",
        "inverse_kinematics",
        "lotka_volterra",
        "sir",
        "slcp",
        "slcp_distractors",
        "composite_two_moons",
        "two_moons",
    ]
)
def simulator(request):
    """Parametrize over every available simulator fixture (all benchmark simulators plus the two-moons pair)."""
    # Each param names a fixture defined above; resolve it lazily per test run.
    return request.getfixturevalue(request.param)

tests/test_simulators/test_simulators.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import numpy as np
33

44

5-
def test_two_moons(simulator, batch_size):
6-
samples = simulator.sample((batch_size,))
5+
def test_two_moons(two_moons_simulator, batch_size):
6+
samples = two_moons_simulator.sample((batch_size,))
77

88
assert isinstance(samples, dict)
99
assert list(samples.keys()) == ["parameters", "observables"]
@@ -13,6 +13,14 @@ def test_two_moons(simulator, batch_size):
1313
assert samples["observables"].shape == (batch_size, 2)
1414

1515

16+
def test_gaussian_linear(gaussian_linear_simulator, batch_size):
17+
samples = gaussian_linear_simulator.sample((batch_size,))
18+
19+
# test n_obs respected if applicable
20+
if hasattr(gaussian_linear_simulator, "n_obs") and isinstance(gaussian_linear_simulator.n_obs, int):
21+
assert samples["observables"].shape[1] == gaussian_linear_simulator.n_obs
22+
23+
1624
def test_sample(simulator, batch_size):
1725
samples = simulator.sample((batch_size,))
1826

0 commit comments

Comments
 (0)