Skip to content

Commit be44a77

Browse files
committed
adapt tests for torch
1 parent da3bacb commit be44a77

File tree

4 files changed

+26
-16
lines changed

4 files changed

+26
-16
lines changed

tests/test_compatibility/test_approximators/test_point_approximator/conftest.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,6 @@ def batch_size():
66
return 8
77

88

9-
@pytest.fixture()
10-
def num_samples():
11-
return 100
12-
13-
149
@pytest.fixture(params=["single_parametric", "multiple_parametric"])
1510
def point_inference_network(request):
1611
match request.param:
@@ -52,3 +47,13 @@ def approximator(adapter, point_inference_network, summary_network, standardize)
5247
summary_network=summary_network,
5348
standardize=standardize,
5449
)
50+
51+
52+
@pytest.fixture()
53+
def adapter():
54+
from bayesflow import ContinuousApproximator
55+
56+
return ContinuousApproximator.build_adapter(
57+
inference_variables=["mean", "std"],
58+
inference_conditions=["x"],
59+
)

tests/test_compatibility/test_approximators/test_point_approximator/test_point_approximator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55

66

77
@pytest.mark.parametrize(
8-
"summary_network,simulator,adapter,standardize",
8+
"summary_network,simulator,standardize",
99
[
10-
["deep_set", "sir", "summary", "all"], # use deep_set for speed
11-
[None, "two_moons", "direct", None],
10+
[None, "normal", "all"],
1211
],
1312
indirect=True,
1413
)

tests/test_compatibility/test_distributions/test_distributions.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,20 @@ def from_config(cls, config, custom_objects=None):
4848

4949
model = DummyModel(distribution)
5050
model.compile(loss=keras.losses.MeanSquaredError())
51-
model.fit(
52-
random_samples,
53-
keras.ops.ones(keras.ops.shape(random_samples)[:-1]),
51+
fit_kwargs = dict(
52+
x=random_samples,
53+
y=keras.ops.ones(keras.ops.shape(random_samples)[:-1]),
5454
batch_size=keras.ops.shape(random_samples)[0],
5555
epochs=1,
5656
)
57+
if keras.backend.backend() == "torch":
58+
import torch
59+
60+
with torch.enable_grad():
61+
model.fit(**fit_kwargs)
62+
else:
63+
model.fit(**fit_kwargs)
64+
5765
model.save(filepaths["model"])
5866

5967
output = self.evaluate(model.distribution, random_samples)

tests/test_compatibility/test_metrics/conftest.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
import numpy as np
2+
import keras
33

44

55
@pytest.fixture()
@@ -20,11 +20,9 @@ def metric(request):
2020

2121
@pytest.fixture
2222
def samples_1():
23-
rng = np.random.default_rng(seed=1)
24-
return rng.normal(size=(2, 3)).astype(np.float32)
23+
return keras.random.normal((2, 3), seed=1)
2524

2625

2726
@pytest.fixture
2827
def samples_2():
29-
rng = np.random.default_rng(seed=2)
30-
return rng.normal(size=(2, 3)).astype(np.float32)
28+
return keras.random.normal((2, 3), seed=2)

0 commit comments

Comments
 (0)