Skip to content

Commit 47c0162

Browse files
committed
Merge branch 'Development' of https://github.com/stefanradev93/BayesFlow into Development
2 parents 927b881 + 14e9de3 commit 47c0162

File tree

1 file changed

+165
-123
lines changed

1 file changed

+165
-123
lines changed

tests/test_trainers.py

Lines changed: 165 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -68,126 +68,168 @@ def _create_training_setup(mode):
6868
trainer = Trainer(generative_model=model, amortizer=amortizer)
6969
return trainer
7070

71-
72-
@pytest.mark.parametrize("mode", ["posterior", "likelihood"])
@pytest.mark.parametrize("reuse_optimizer", [True, False])
@pytest.mark.parametrize("validation_sims", [20, None])
def test_train_online(mode, reuse_optimizer, validation_sims):
    """Run a short online training session and check what it leaves behind.

    Checks that the trainer keeps an optimizer only when ``reuse_optimizer``
    is set, and that the returned loss history is a ``DataFrame`` without
    validation or a dict holding two ``DataFrame`` objects with validation.
    """

    train_kwargs = dict(
        epochs=2,
        iterations_per_epoch=3,
        batch_size=8,
        use_autograph=False,
        reuse_optimizer=reuse_optimizer,
        validation_sims=validation_sims,
    )
    trainer = _create_training_setup(mode)
    history = trainer.train_online(**train_kwargs)

    # The optimizer must be present exactly when reuse was requested.
    has_optimizer = trainer.optimizer is not None
    assert has_optimizer == reuse_optimizer

    # With validation the history is a dict of train/val frames.
    if validation_sims is not None:
        assert type(history) is dict
        for key in ("train_losses", "val_losses"):
            assert type(history[key]) is DataFrame
    else:
        assert type(history) is DataFrame
102-
103-
104-
@pytest.mark.parametrize("mode", ["posterior", "joint"])
@pytest.mark.parametrize("reuse_optimizer", [True, False])
@pytest.mark.parametrize("validation_sims", [20, None])
def test_train_experience_replay(mode, reuse_optimizer, validation_sims):
    """Run a short experience-replay training session and check its results.

    Checks the presence of the optimizer on the trainer against
    ``reuse_optimizer`` and the container type of the returned loss history
    against ``validation_sims``.
    """

    trainer = _create_training_setup(mode)
    history = trainer.train_experience_replay(
        epochs=3,
        iterations_per_epoch=4,
        batch_size=8,
        validation_sims=validation_sims,
        reuse_optimizer=reuse_optimizer,
    )

    # Optimizer is retained only on request.
    if not reuse_optimizer:
        assert trainer.optimizer is None
    else:
        assert trainer.optimizer is not None

    # Without validation a single frame is returned; otherwise a dict of two.
    validated = validation_sims is not None
    if not validated:
        assert type(history) is DataFrame
    else:
        assert type(history) is dict
        train_part, val_part = history["train_losses"], history["val_losses"]
        assert type(train_part) is DataFrame
        assert type(val_part) is DataFrame
129-
130-
131-
@pytest.mark.parametrize("mode", ["likelihood", "joint"])
@pytest.mark.parametrize("reuse_optimizer", [True, False])
@pytest.mark.parametrize("validation_sims", [20, None])
def test_train_offline(mode, reuse_optimizer, validation_sims):
    """Train on a pre-simulated data set and check the training by-products.

    Checks that the optimizer is kept on the trainer only when
    ``reuse_optimizer`` is set, and that the loss history has the expected
    container type for the given ``validation_sims``.
    """

    trainer = _create_training_setup(mode)

    # Simulate the training set up front and hand it to the trainer.
    offline_data = trainer.generative_model(100)
    offline_kwargs = {
        "simulations_dict": offline_data,
        "epochs": 2,
        "batch_size": 16,
        "use_autograph": True,
        "validation_sims": validation_sims,
        "reuse_optimizer": reuse_optimizer,
    }
    history = trainer.train_offline(**offline_kwargs)

    # Optimizer presence must mirror the reuse flag.
    optimizer_kept = trainer.optimizer is not None
    assert optimizer_kept == reuse_optimizer

    # History is a bare frame without validation, a dict of frames with it.
    if validation_sims is None:
        assert type(history) is DataFrame
        return
    assert type(history) is dict
    assert type(history["train_losses"]) is DataFrame
    assert type(history["val_losses"]) is DataFrame
162-
163-
164-
@pytest.mark.parametrize("mode", ["likelihood", "posterior"])
@pytest.mark.parametrize("reuse_optimizer", [True, False])
@pytest.mark.parametrize("validation_sims", [20, None])
def test_train_rounds(mode, reuse_optimizer, validation_sims):
    """Run a short round-based training session and check its by-products.

    Checks the optimizer's presence on the trainer against
    ``reuse_optimizer`` and the container type of the returned loss history
    against ``validation_sims``.
    """

    trainer = _create_training_setup(mode)
    round_kwargs = dict(
        rounds=2,
        sim_per_round=32,
        epochs=2,
        batch_size=8,
        validation_sims=validation_sims,
        reuse_optimizer=reuse_optimizer,
    )
    history = trainer.train_rounds(**round_kwargs)

    # The optimizer stays attached only when reuse was requested.
    if reuse_optimizer:
        assert trainer.optimizer is not None
    else:
        assert trainer.optimizer is None

    # With validation the history is a dict of train/val frames.
    if validation_sims is not None:
        assert type(history) is dict
        for loss_key in ("train_losses", "val_losses"):
            assert type(history[loss_key]) is DataFrame
    else:
        assert type(history) is DataFrame
71+
class TestTrainer:
    """Smoke tests for the training regimes and the MMD test of ``Trainer``.

    Each test picks one of the trainers built in ``setup_method`` by its mode
    name, runs a short training session and checks the by-products.
    """

    def setup_method(self):
        """Build one trainer per mode and store them in ``self.trainers``.

        This is pytest's per-test hook. The nose-style name ``setup`` is no
        longer invoked by pytest >= 8, which would leave ``self.trainers``
        undefined in every test.
        """
        self.trainers = {
            mode: _create_training_setup(mode)
            for mode in ("posterior", "likelihood", "joint")
        }

    @staticmethod
    def _assert_optimizer_state(trainer, reuse_optimizer):
        """Assert the trainer holds an optimizer iff ``reuse_optimizer`` is set."""
        if reuse_optimizer:
            assert trainer.optimizer is not None
        else:
            assert trainer.optimizer is None

    @staticmethod
    def _assert_history_format(history, validation_sims):
        """Assert the loss history has the container type expected.

        Without validation the history must be a ``DataFrame``; with
        validation it must be a dict whose ``"train_losses"`` and
        ``"val_losses"`` entries are ``DataFrame`` objects.
        """
        if validation_sims is None:
            assert type(history) is DataFrame
        else:
            assert type(history) is dict
            assert type(history["train_losses"]) is DataFrame
            assert type(history["val_losses"]) is DataFrame

    @pytest.mark.parametrize("mode", ["posterior", "likelihood"])
    @pytest.mark.parametrize("reuse_optimizer", [True, False])
    @pytest.mark.parametrize("validation_sims", [20, None])
    def test_train_online(self, mode, reuse_optimizer, validation_sims):
        """Tests the online training functionality."""

        trainer = self.trainers[mode]
        h = trainer.train_online(
            epochs=2,
            iterations_per_epoch=3,
            batch_size=8,
            use_autograph=False,
            reuse_optimizer=reuse_optimizer,
            validation_sims=validation_sims,
        )

        self._assert_optimizer_state(trainer, reuse_optimizer)
        self._assert_history_format(h, validation_sims)

    @pytest.mark.parametrize("mode", ["posterior", "joint"])
    @pytest.mark.parametrize("reuse_optimizer", [True, False])
    @pytest.mark.parametrize("validation_sims", [20, None])
    def test_train_experience_replay(self, mode, reuse_optimizer, validation_sims):
        """Tests the experience replay training functionality."""

        trainer = self.trainers[mode]
        h = trainer.train_experience_replay(
            epochs=3,
            iterations_per_epoch=4,
            batch_size=8,
            validation_sims=validation_sims,
            reuse_optimizer=reuse_optimizer,
        )

        self._assert_optimizer_state(trainer, reuse_optimizer)
        self._assert_history_format(h, validation_sims)

    @pytest.mark.parametrize("mode", ["likelihood", "joint"])
    @pytest.mark.parametrize("reuse_optimizer", [True, False])
    @pytest.mark.parametrize("validation_sims", [20, None])
    def test_train_offline(self, mode, reuse_optimizer, validation_sims):
        """Tests the offline training functionality."""

        trainer = self.trainers[mode]
        # Simulate the training set up front and hand it to the trainer.
        simulations = trainer.generative_model(100)
        h = trainer.train_offline(
            simulations_dict=simulations,
            epochs=2,
            batch_size=16,
            use_autograph=True,
            validation_sims=validation_sims,
            reuse_optimizer=reuse_optimizer,
        )

        self._assert_optimizer_state(trainer, reuse_optimizer)
        self._assert_history_format(h, validation_sims)

    @pytest.mark.parametrize("mode", ["likelihood", "posterior"])
    @pytest.mark.parametrize("reuse_optimizer", [True, False])
    @pytest.mark.parametrize("validation_sims", [20, None])
    def test_train_rounds(self, mode, reuse_optimizer, validation_sims):
        """Tests the round-based training functionality."""

        trainer = self.trainers[mode]
        h = trainer.train_rounds(
            rounds=2,
            sim_per_round=32,
            epochs=2,
            batch_size=8,
            validation_sims=validation_sims,
            reuse_optimizer=reuse_optimizer,
        )

        self._assert_optimizer_state(trainer, reuse_optimizer)
        self._assert_history_format(h, validation_sims)

    @pytest.mark.parametrize("reference_data", [None, "dict", "numpy"])
    @pytest.mark.parametrize("observed_data_type", ["dict", "numpy"])
    @pytest.mark.parametrize("bootstrap", [True, False])
    def mmd_hypothesis_test_no_reference(self, reference_data, observed_data_type, bootstrap):
        """Call ``mmd_hypothesis_test`` with differently shaped inputs.

        ``reference_data`` arrives as a tag (``None``, ``"dict"`` or
        ``"numpy"``) and is replaced by the corresponding data before the
        call; ``None`` is passed through unchanged. ``observed_data_type``
        selects the form of the observed data in the same way.
        """
        # NOTE(review): under pytest's default discovery rules this method is
        # not collected because its name lacks the ``test`` prefix. Confirm
        # the assertions below against ``Trainer.mmd_hypothesis_test`` before
        # renaming it to enable the test.
        trainer = self.trainers["posterior"]
        _ = trainer.train_online(epochs=1, iterations_per_epoch=1, batch_size=4)

        num_reference_simulations = 10
        num_observed_simulations = 2
        num_null_samples = 5

        # Resolve the tag into actual reference data. The original nested
        # these branches under ``if reference_data is None``, so they could
        # never run and every case silently degenerated to ``None``.
        if reference_data == "dict":
            reference_data = trainer.configurator(trainer.generative_model(num_reference_simulations))
        elif reference_data == "numpy":
            # Presumably a NumPy array, going by the tag name - TODO confirm.
            reference_data = trainer.configurator(
                trainer.generative_model(num_reference_simulations)
            )["summary_conditions"]

        if observed_data_type == "dict":
            observed_data = trainer.configurator(trainer.generative_model(num_observed_simulations))
        elif observed_data_type == "numpy":
            observed_data = trainer.configurator(
                trainer.generative_model(num_observed_simulations)
            )["summary_conditions"]

        MMD_sampling_distribution, MMD_observed = trainer.mmd_hypothesis_test(
            observed_data=observed_data,
            reference_data=reference_data,
            num_reference_simulations=num_reference_simulations,
            num_null_samples=num_null_samples,
            bootstrap=bootstrap,
        )

        # NOTE(review): the length of the sampling distribution is compared to
        # ``num_reference_simulations``; ``num_null_samples`` looks like the
        # more plausible expectation - verify against the trainer.
        assert MMD_sampling_distribution.shape[0] == num_reference_simulations
        assert np.all(MMD_sampling_distribution > 0)

0 commit comments

Comments
 (0)