Skip to content

Commit fe63666

Browse files
authored
Minimal Fix for Broken Tests (#130)
* typo * Minimal Fix for Tests
1 parent ccfd11a commit fe63666

File tree

3 files changed

+5
-37
lines changed

3 files changed

+5
-37
lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# From https://github.com/eeholmes/readthedoc-test/blob/main/.github/workflows/docs_pages.yml
22
name: docs
33

4-
# execute this workflow automatically when a we push to master
4+
# execute this workflow automatically when we push to master
55
on:
66
push:
77
branches:

tests/test_sensitivity.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1+
from copy import deepcopy
12
from functools import partial
23
from unittest.mock import MagicMock, Mock
34

45
import numpy as np
56
import pytest
67
import tensorflow as tf
78

8-
from copy import deepcopy
9-
109
from bayesflow import computational_utilities, sensitivity, simulation
1110
from tests.test_trainers import _create_training_setup, _prior, _simulator
1211

@@ -26,7 +25,7 @@ def _trainer_amortizer_sample_mock(input_dict, n_samples):
2625

2726

2827
class TestMisspecificationExperiment:
29-
def setup(self):
28+
def setup_method(self):
3029
self.trainer = _create_training_setup(mode="posterior")
3130

3231
# Mock the approximate posterior sampling of the amortizer, return np.ones of shape (n_sim, n_samples)

tests/test_trainers.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ def _create_training_setup(mode):
6868
trainer = Trainer(generative_model=model, amortizer=amortizer)
6969
return trainer
7070

71+
7172
class TestTrainer:
72-
def setup(self):
73+
def setup_method(self):
7374
trainer_posterior = _create_training_setup("posterior")
7475
trainer_likelihood = _create_training_setup("likelihood")
7576
trainer_joint = _create_training_setup("joint")
@@ -111,7 +112,6 @@ def test_train_online(self, mode, reuse_optimizer, validation_sims):
111112
assert type(h["train_losses"]) is DataFrame
112113
assert type(h["val_losses"]) is DataFrame
113114

114-
115115
@pytest.mark.parametrize("mode", ["posterior", "joint"])
116116
@pytest.mark.parametrize("reuse_optimizer", [True, False])
117117
@pytest.mark.parametrize("validation_sims", [20, None])
@@ -202,34 +202,3 @@ def test_train_rounds(self, mode, reuse_optimizer, validation_sims):
202202
assert type(h) is dict
203203
assert type(h["train_losses"]) is DataFrame
204204
assert type(h["val_losses"]) is DataFrame
205-
206-
@pytest.mark.parametrize("reference_data", [None, "dict", "numpy"])
207-
@pytest.mark.parametrize("observed_data_type", ["dict", "numpy"])
208-
@pytest.mark.parametrize("bootstrap", [True, False])
209-
def mmd_hypothesis_test_no_reference(self, reference_data, observed_data_type, bootstrap):
210-
trainer = self.trainers["posterior"]
211-
_ = trainer.train_online(epochs=1, iterations_per_epoch=1, batch_size=4)
212-
213-
num_reference_simulations = 10
214-
num_observed_simulations = 2
215-
num_null_samples = 5
216-
217-
if reference_data is None:
218-
if reference_data == "dict":
219-
reference_data = trainer.configurator(trainer.generative_model(num_reference_simulations))
220-
elif reference_data == "numpy":
221-
reference_data = trainer.configurator(trainer.generative_model(num_reference_simulations))['summary_conditions']
222-
223-
if observed_data_type == "dict":
224-
observed_data = trainer.configurator(trainer.generative_model(num_observed_simulations))
225-
elif observed_data_type == "numpy":
226-
observed_data = trainer.configurator(trainer.generative_model(num_observed_simulations))['summary_conditions']
227-
228-
MMD_sampling_distribution, MMD_observed = trainer.mmd_hypothesis_test(observed_data=observed_data,
229-
reference_data=reference_data,
230-
num_reference_simulations=num_reference_simulations,
231-
num_null_samples=num_null_samples,
232-
bootstrap=bootstrap)
233-
234-
assert MMD_sampling_distribution.shape[0] == num_reference_simulations
235-
assert np.all(MMD_sampling_distribution > 0)

0 commit comments

Comments
 (0)