@@ -68,126 +68,168 @@ def _create_training_setup(mode):
6868 trainer = Trainer (generative_model = model , amortizer = amortizer )
6969 return trainer
7070
71-
72- @pytest .mark .parametrize ("mode" , ["posterior" , "likelihood" ])
73- @pytest .mark .parametrize ("reuse_optimizer" , [True , False ])
74- @pytest .mark .parametrize ("validation_sims" , [20 , None ])
75- def test_train_online (mode , reuse_optimizer , validation_sims ):
76- """Tests the online training functionality."""
77-
78- # Create trainer and train online
79- trainer = _create_training_setup (mode )
80- h = trainer .train_online (
81- epochs = 2 ,
82- iterations_per_epoch = 3 ,
83- batch_size = 8 ,
84- use_autograph = False ,
85- reuse_optimizer = reuse_optimizer ,
86- validation_sims = validation_sims ,
87- )
88-
89- # Assert (non)-existence of optimizer
90- if reuse_optimizer :
91- assert trainer .optimizer is not None
92- else :
93- assert trainer .optimizer is None
94-
95- # Ensure losses were stored in the correct format
96- if validation_sims is None :
97- assert type (h ) is DataFrame
98- else :
99- assert type (h ) is dict
100- assert type (h ["train_losses" ]) is DataFrame
101- assert type (h ["val_losses" ]) is DataFrame
102-
103-
104- @pytest .mark .parametrize ("mode" , ["posterior" , "joint" ])
105- @pytest .mark .parametrize ("reuse_optimizer" , [True , False ])
106- @pytest .mark .parametrize ("validation_sims" , [20 , None ])
107- def test_train_experience_replay (mode , reuse_optimizer , validation_sims ):
108- """Tests the experience replay training functionality."""
109-
110- # Create trainer and train with experience replay
111- trainer = _create_training_setup (mode )
112- h = trainer .train_experience_replay (
113- epochs = 3 , iterations_per_epoch = 4 , batch_size = 8 , validation_sims = validation_sims , reuse_optimizer = reuse_optimizer
114- )
115-
116- # Assert (non)-existence of optimizer
117- if reuse_optimizer :
118- assert trainer .optimizer is not None
119- else :
120- assert trainer .optimizer is None
121-
122- # Ensure losses were stored in the correct format
123- if validation_sims is None :
124- assert type (h ) is DataFrame
125- else :
126- assert type (h ) is dict
127- assert type (h ["train_losses" ]) is DataFrame
128- assert type (h ["val_losses" ]) is DataFrame
129-
130-
131- @pytest .mark .parametrize ("mode" , ["likelihood" , "joint" ])
132- @pytest .mark .parametrize ("reuse_optimizer" , [True , False ])
133- @pytest .mark .parametrize ("validation_sims" , [20 , None ])
134- def test_train_offline (mode , reuse_optimizer , validation_sims ):
135- """Tests the offline training functionality."""
136-
137- # Create trainer and data and train offline
138- trainer = _create_training_setup (mode )
139- simulations = trainer .generative_model (100 )
140- h = trainer .train_offline (
141- simulations_dict = simulations ,
142- epochs = 2 ,
143- batch_size = 16 ,
144- use_autograph = True ,
145- validation_sims = validation_sims ,
146- reuse_optimizer = reuse_optimizer ,
147- )
148-
149- # Assert (non)-existence of optimizer
150- if reuse_optimizer :
151- assert trainer .optimizer is not None
152- else :
153- assert trainer .optimizer is None
154-
155- # Ensure losses were stored in the correct format
156- if validation_sims is None :
157- assert type (h ) is DataFrame
158- else :
159- assert type (h ) is dict
160- assert type (h ["train_losses" ]) is DataFrame
161- assert type (h ["val_losses" ]) is DataFrame
162-
163-
164- @pytest .mark .parametrize ("mode" , ["likelihood" , "posterior" ])
165- @pytest .mark .parametrize ("reuse_optimizer" , [True , False ])
166- @pytest .mark .parametrize ("validation_sims" , [20 , None ])
167- def test_train_rounds (mode , reuse_optimizer , validation_sims ):
168- """Tests the offline training functionality."""
169-
170- # Create trainer and data and train offline
171- trainer = _create_training_setup (mode )
172- h = trainer .train_rounds (
173- rounds = 2 ,
174- sim_per_round = 32 ,
175- epochs = 2 ,
176- batch_size = 8 ,
177- validation_sims = validation_sims ,
178- reuse_optimizer = reuse_optimizer ,
179- )
180-
181- # Assert (non)-existence of optimizer
182- if reuse_optimizer :
183- assert trainer .optimizer is not None
184- else :
185- assert trainer .optimizer is None
186-
187- # Ensure losses were stored in the correct format
188- if validation_sims is None :
189- assert type (h ) is DataFrame
190- else :
191- assert type (h ) is dict
192- assert type (h ["train_losses" ]) is DataFrame
193- assert type (h ["val_losses" ]) is DataFrame
71+ class TestTrainer :
72+ def setup (self ):
73+ trainer_posterior = _create_training_setup ("posterior" )
74+ trainer_likelihood = _create_training_setup ("likelihood" )
75+ trainer_joint = _create_training_setup ("joint" )
76+ self .trainers = {
77+ "posterior" : trainer_posterior ,
78+ "likelihood" : trainer_likelihood ,
79+ "joint" : trainer_joint
80+ }
81+
82+
83+ @pytest .mark .parametrize ("mode" , ["posterior" , "likelihood" ])
84+ @pytest .mark .parametrize ("reuse_optimizer" , [True , False ])
85+ @pytest .mark .parametrize ("validation_sims" , [20 , None ])
86+ def test_train_online (self , mode , reuse_optimizer , validation_sims ):
87+ """Tests the online training functionality."""
88+
89+ # Create trainer and train online
90+ trainer = self .trainers [mode ]
91+ h = trainer .train_online (
92+ epochs = 2 ,
93+ iterations_per_epoch = 3 ,
94+ batch_size = 8 ,
95+ use_autograph = False ,
96+ reuse_optimizer = reuse_optimizer ,
97+ validation_sims = validation_sims ,
98+ )
99+
100+ # Assert (non)-existence of optimizer
101+ if reuse_optimizer :
102+ assert trainer .optimizer is not None
103+ else :
104+ assert trainer .optimizer is None
105+
106+ # Ensure losses were stored in the correct format
107+ if validation_sims is None :
108+ assert type (h ) is DataFrame
109+ else :
110+ assert type (h ) is dict
111+ assert type (h ["train_losses" ]) is DataFrame
112+ assert type (h ["val_losses" ]) is DataFrame
113+
114+
115+ @pytest .mark .parametrize ("mode" , ["posterior" , "joint" ])
116+ @pytest .mark .parametrize ("reuse_optimizer" , [True , False ])
117+ @pytest .mark .parametrize ("validation_sims" , [20 , None ])
118+ def test_train_experience_replay (self , mode , reuse_optimizer , validation_sims ):
119+ """Tests the experience replay training functionality."""
120+
121+ # Create trainer and train with experience replay
122+ trainer = self .trainers [mode ]
123+ h = trainer .train_experience_replay (
124+ epochs = 3 , iterations_per_epoch = 4 , batch_size = 8 , validation_sims = validation_sims , reuse_optimizer = reuse_optimizer
125+ )
126+
127+ # Assert (non)-existence of optimizer
128+ if reuse_optimizer :
129+ assert trainer .optimizer is not None
130+ else :
131+ assert trainer .optimizer is None
132+
133+ # Ensure losses were stored in the correct format
134+ if validation_sims is None :
135+ assert type (h ) is DataFrame
136+ else :
137+ assert type (h ) is dict
138+ assert type (h ["train_losses" ]) is DataFrame
139+ assert type (h ["val_losses" ]) is DataFrame
140+
141+
142+ @pytest .mark .parametrize ("mode" , ["likelihood" , "joint" ])
143+ @pytest .mark .parametrize ("reuse_optimizer" , [True , False ])
144+ @pytest .mark .parametrize ("validation_sims" , [20 , None ])
145+ def test_train_offline (self , mode , reuse_optimizer , validation_sims ):
146+ """Tests the offline training functionality."""
147+
148+ # Create trainer and data and train offline
149+ trainer = self .trainers [mode ]
150+ simulations = trainer .generative_model (100 )
151+ h = trainer .train_offline (
152+ simulations_dict = simulations ,
153+ epochs = 2 ,
154+ batch_size = 16 ,
155+ use_autograph = True ,
156+ validation_sims = validation_sims ,
157+ reuse_optimizer = reuse_optimizer ,
158+ )
159+
160+ # Assert (non)-existence of optimizer
161+ if reuse_optimizer :
162+ assert trainer .optimizer is not None
163+ else :
164+ assert trainer .optimizer is None
165+
166+ # Ensure losses were stored in the correct format
167+ if validation_sims is None :
168+ assert type (h ) is DataFrame
169+ else :
170+ assert type (h ) is dict
171+ assert type (h ["train_losses" ]) is DataFrame
172+ assert type (h ["val_losses" ]) is DataFrame
173+
174+
175+ @pytest .mark .parametrize ("mode" , ["likelihood" , "posterior" ])
176+ @pytest .mark .parametrize ("reuse_optimizer" , [True , False ])
177+ @pytest .mark .parametrize ("validation_sims" , [20 , None ])
178+ def test_train_rounds (self , mode , reuse_optimizer , validation_sims ):
179+ """Tests the offline training functionality."""
180+
181+ # Create trainer and data and train offline
182+ trainer = self .trainers [mode ]
183+ h = trainer .train_rounds (
184+ rounds = 2 ,
185+ sim_per_round = 32 ,
186+ epochs = 2 ,
187+ batch_size = 8 ,
188+ validation_sims = validation_sims ,
189+ reuse_optimizer = reuse_optimizer ,
190+ )
191+
192+ # Assert (non)-existence of optimizer
193+ if reuse_optimizer :
194+ assert trainer .optimizer is not None
195+ else :
196+ assert trainer .optimizer is None
197+
198+ # Ensure losses were stored in the correct format
199+ if validation_sims is None :
200+ assert type (h ) is DataFrame
201+ else :
202+ assert type (h ) is dict
203+ assert type (h ["train_losses" ]) is DataFrame
204+ assert type (h ["val_losses" ]) is DataFrame
205+
206+ @pytest .mark .parametrize ("reference_data" , [None , "dict" , "numpy" ])
207+ @pytest .mark .parametrize ("observed_data_type" , ["dict" , "numpy" ])
208+ @pytest .mark .parametrize ("bootstrap" , [True , False ])
209+ def mmd_hypothesis_test_no_reference (self , reference_data , observed_data_type , bootstrap ):
210+ trainer = self .trainers ["posterior" ]
211+ _ = trainer .train_online (epochs = 1 , iterations_per_epoch = 1 , batch_size = 4 )
212+
213+ num_reference_simulations = 10
214+ num_observed_simulations = 2
215+ num_null_samples = 5
216+
217+ if reference_data is None :
218+ if reference_data == "dict" :
219+ reference_data = trainer .configurator (trainer .generative_model (num_reference_simulations ))
220+ elif reference_data == "numpy" :
221+ reference_data = trainer .configurator (trainer .generative_model (num_reference_simulations ))['summary_conditions' ]
222+
223+ if observed_data_type == "dict" :
224+ observed_data = trainer .configurator (trainer .generative_model (num_observed_simulations ))
225+ elif observed_data_type == "numpy" :
226+ observed_data = trainer .configurator (trainer .generative_model (num_observed_simulations ))['summary_conditions' ]
227+
228+ MMD_sampling_distribution , MMD_observed = trainer .mmd_hypothesis_test (observed_data = observed_data ,
229+ reference_data = reference_data ,
230+ num_reference_simulations = num_reference_simulations ,
231+ num_null_samples = num_null_samples ,
232+ bootstrap = bootstrap )
233+
234+ assert MMD_sampling_distribution .shape [0 ] == num_reference_simulations
235+ assert np .all (MMD_sampling_distribution > 0 )
0 commit comments