Skip to content

Commit 1aaa55f

Browse files
committed
fix two moons simulator test
1 parent 30d0107 commit 1aaa55f

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

tests/test_simulators/conftest.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,18 @@ def use_squeezed(request):
2626
def composite_two_moons():
2727
from bayesflow.simulators import make_simulator
2828

29-
def contexts():
30-
r = np.random.normal(0.1, 0.01)
31-
alpha = np.random.uniform(-0.5 * np.pi, 0.5 * np.pi)
32-
return dict(r=r, alpha=alpha)
33-
3429
def parameters():
35-
return dict(theta=np.random.uniform(-1.0, 1.0, size=2))
30+
parameters = np.random.uniform(-1.0, 1.0, size=2)
31+
return dict(parameters=parameters)
3632

37-
def observables(r, alpha, theta):
38-
x1 = -np.abs(theta[0] + theta[1]) / np.sqrt(2.0) + r * np.cos(alpha) + 0.25
39-
x2 = (-theta[0] + theta[1]) / np.sqrt(2.0) + r * np.sin(alpha)
40-
return dict(x=np.stack([x1, x2]))
33+
def observables(parameters):
34+
r = np.random.normal(0.1, 0.01)
35+
alpha = np.random.uniform(-0.5 * np.pi, 0.5 * np.pi)
36+
x1 = -np.abs(parameters[0] + parameters[1]) / np.sqrt(2.0) + r * np.cos(alpha) + 0.25
37+
x2 = (-parameters[0] + parameters[1]) / np.sqrt(2.0) + r * np.sin(alpha)
38+
return dict(observables=np.stack([x1, x2]))
4139

42-
return make_simulator([contexts, parameters, observables])
40+
return make_simulator([parameters, observables])
4341

4442

4543
@pytest.fixture(params=["composite_two_moons", "two_moons"])

0 commit comments

Comments
 (0)