@@ -68,8 +68,9 @@ def _create_training_setup(mode):
6868 trainer = Trainer (generative_model = model , amortizer = amortizer )
6969 return trainer
7070
71+
7172class TestTrainer :
72- def setup (self ):
73+ def setup_method (self ):
7374 trainer_posterior = _create_training_setup ("posterior" )
7475 trainer_likelihood = _create_training_setup ("likelihood" )
7576 trainer_joint = _create_training_setup ("joint" )
@@ -111,7 +112,6 @@ def test_train_online(self, mode, reuse_optimizer, validation_sims):
111112 assert type (h ["train_losses" ]) is DataFrame
112113 assert type (h ["val_losses" ]) is DataFrame
113114
114-
115115 @pytest .mark .parametrize ("mode" , ["posterior" , "joint" ])
116116 @pytest .mark .parametrize ("reuse_optimizer" , [True , False ])
117117 @pytest .mark .parametrize ("validation_sims" , [20 , None ])
@@ -202,34 +202,3 @@ def test_train_rounds(self, mode, reuse_optimizer, validation_sims):
202202 assert type (h ) is dict
203203 assert type (h ["train_losses" ]) is DataFrame
204204 assert type (h ["val_losses" ]) is DataFrame
205-
206- @pytest .mark .parametrize ("reference_data" , [None , "dict" , "numpy" ])
207- @pytest .mark .parametrize ("observed_data_type" , ["dict" , "numpy" ])
208- @pytest .mark .parametrize ("bootstrap" , [True , False ])
209- def mmd_hypothesis_test_no_reference (self , reference_data , observed_data_type , bootstrap ):
210- trainer = self .trainers ["posterior" ]
211- _ = trainer .train_online (epochs = 1 , iterations_per_epoch = 1 , batch_size = 4 )
212-
213- num_reference_simulations = 10
214- num_observed_simulations = 2
215- num_null_samples = 5
216-
217- if reference_data is None :
218- if reference_data == "dict" :
219- reference_data = trainer .configurator (trainer .generative_model (num_reference_simulations ))
220- elif reference_data == "numpy" :
221- reference_data = trainer .configurator (trainer .generative_model (num_reference_simulations ))['summary_conditions' ]
222-
223- if observed_data_type == "dict" :
224- observed_data = trainer .configurator (trainer .generative_model (num_observed_simulations ))
225- elif observed_data_type == "numpy" :
226- observed_data = trainer .configurator (trainer .generative_model (num_observed_simulations ))['summary_conditions' ]
227-
228- MMD_sampling_distribution , MMD_observed = trainer .mmd_hypothesis_test (observed_data = observed_data ,
229- reference_data = reference_data ,
230- num_reference_simulations = num_reference_simulations ,
231- num_null_samples = num_null_samples ,
232- bootstrap = bootstrap )
233-
234- assert MMD_sampling_distribution .shape [0 ] == num_reference_simulations
235- assert np .all (MMD_sampling_distribution > 0 )
0 commit comments