@@ -161,20 +161,26 @@ def test_compute_hypothesis_test_from_summaries_num_null_samples_exceeds_referen
161161 )
162162
163163
164- @pytest .mark .parametrize ("summary_network" , [lambda data : data + 1 , None ])
165- def test_compute_hypothesis_test_shapes (summary_network , monkeypatch ):
164+ @pytest .mark .parametrize (
165+ "summary_network, is_true_approximator" ,
166+ [(lambda data : data + 1 , True ), (None , True ), (lambda data : data + 1 , False )],
167+ )
168+ def test_compute_hypothesis_test_shapes (summary_network , is_true_approximator , monkeypatch ):
166169 """Test the compute_mmd_hypothesis_test output shapes."""
167170 observed_data = np .random .rand (10 , 5 )
168171 reference_data = np .random .rand (100 , 5 )
169172 num_null_samples = 50
170173
171- mock_approximator = bf .approximators .ContinuousApproximator (
172- adapter = None ,
173- inference_network = None ,
174- summary_network = None ,
175- )
176-
177- monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
174+ if is_true_approximator :
175+ mock_approximator = bf .approximators .ContinuousApproximator (
176+ adapter = None ,
177+ inference_network = None ,
178+ summary_network = None ,
179+ )
180+ monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
181+ else :
182+ mock_approximator = bf .networks .SummaryNetwork ()
183+ monkeypatch .setattr (mock_approximator , "call" , summary_network )
178184
179185 mmd_observed , mmd_null = bf .diagnostics .metrics .compute_mmd_hypothesis_test (
180186 observed_data , reference_data , mock_approximator , num_null_samples = num_null_samples
@@ -185,20 +191,26 @@ def test_compute_hypothesis_test_shapes(summary_network, monkeypatch):
185191 assert mmd_null .shape == (num_null_samples ,)
186192
187193
188- @pytest .mark .parametrize ("summary_network" , [lambda data : data + 1 , None ])
189- def test_compute_hypothesis_test_positive (summary_network , monkeypatch ):
194+ @pytest .mark .parametrize (
195+ "summary_network, is_true_approximator" ,
196+ [(lambda data : data + 1 , True ), (None , True ), (lambda data : data + 1 , False )],
197+ )
198+ def test_compute_hypothesis_test_positive (summary_network , is_true_approximator , monkeypatch ):
190199 """Test MMD output values of compute_hypothesis_test are positive."""
191200 observed_data = np .random .rand (10 , 5 )
192201 reference_data = np .random .rand (100 , 5 )
193202 num_null_samples = 50
194203
195- mock_approximator = bf .approximators .ContinuousApproximator (
196- adapter = None ,
197- inference_network = None ,
198- summary_network = None ,
199- )
200-
201- monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
204+ if is_true_approximator :
205+ mock_approximator = bf .approximators .ContinuousApproximator (
206+ adapter = None ,
207+ inference_network = None ,
208+ summary_network = None ,
209+ )
210+ monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
211+ else :
212+ mock_approximator = bf .networks .SummaryNetwork ()
213+ monkeypatch .setattr (mock_approximator , "call" , summary_network )
202214
203215 mmd_observed , mmd_null = bf .diagnostics .metrics .compute_mmd_hypothesis_test (
204216 observed_data , reference_data , mock_approximator , num_null_samples = num_null_samples
@@ -208,20 +220,26 @@ def test_compute_hypothesis_test_positive(summary_network, monkeypatch):
208220 assert np .all (mmd_null >= 0 )
209221
210222
211- @pytest .mark .parametrize ("summary_network" , [lambda data : data + 1 , None ])
212- def test_compute_hypothesis_test_same_distribution (summary_network , monkeypatch ):
223+ @pytest .mark .parametrize (
224+ "summary_network, is_true_approximator" ,
225+ [(lambda data : data + 1 , True ), (None , True ), (lambda data : data + 1 , False )],
226+ )
227+ def test_compute_hypothesis_test_same_distribution (summary_network , is_true_approximator , monkeypatch ):
213228 """Test compute_hypothesis_test on same distributions."""
214229 observed_data = np .random .rand (10 , 5 )
215230 reference_data = observed_data .copy ()
216231 num_null_samples = 5
217232
218- mock_approximator = bf .approximators .ContinuousApproximator (
219- adapter = None ,
220- inference_network = None ,
221- summary_network = None ,
222- )
223-
224- monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
233+ if is_true_approximator :
234+ mock_approximator = bf .approximators .ContinuousApproximator (
235+ adapter = None ,
236+ inference_network = None ,
237+ summary_network = None ,
238+ )
239+ monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
240+ else :
241+ mock_approximator = bf .networks .SummaryNetwork ()
242+ monkeypatch .setattr (mock_approximator , "call" , summary_network )
225243
226244 mmd_observed , mmd_null = bf .diagnostics .metrics .compute_mmd_hypothesis_test (
227245 observed_data , reference_data , mock_approximator , num_null_samples = num_null_samples
@@ -230,19 +248,26 @@ def test_compute_hypothesis_test_same_distribution(summary_network, monkeypatch)
230248 assert mmd_observed <= np .quantile (mmd_null , 0.99 )
231249
232250
233- @pytest .mark .parametrize ("summary_network" , [lambda data : data + 1 , None ])
234- def test_compute_hypothesis_test_different_distributions (summary_network , monkeypatch ):
251+ @pytest .mark .parametrize (
252+ "summary_network, is_true_approximator" ,
253+ [(lambda data : data + 1 , True ), (None , True ), (lambda data : data + 1 , False )],
254+ )
255+ def test_compute_hypothesis_test_different_distributions (summary_network , is_true_approximator , monkeypatch ):
235256 """Test compute_hypothesis_test on different distributions."""
236257 observed_data = np .random .rand (10 , 5 )
237258 reference_data = np .random .normal (loc = 0.5 , scale = 0.1 , size = (100 , 5 ))
238259 num_null_samples = 50
239260
240- mock_approximator = bf .approximators .ContinuousApproximator (
241- adapter = None ,
242- inference_network = None ,
243- summary_network = None ,
244- )
245- monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
261+ if is_true_approximator :
262+ mock_approximator = bf .approximators .ContinuousApproximator (
263+ adapter = None ,
264+ inference_network = None ,
265+ summary_network = None ,
266+ )
267+ monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
268+ else :
269+ mock_approximator = bf .networks .SummaryNetwork ()
270+ monkeypatch .setattr (mock_approximator , "call" , summary_network )
246271
247272 mmd_observed , mmd_null = bf .diagnostics .metrics .compute_mmd_hypothesis_test (
248273 observed_data , reference_data , mock_approximator , num_null_samples = num_null_samples
@@ -251,20 +276,26 @@ def test_compute_hypothesis_test_different_distributions(summary_network, monkey
251276 assert mmd_observed >= np .quantile (mmd_null , 0.68 )
252277
253278
254- @pytest .mark .parametrize ("summary_network" , [lambda data : data + 1 , None ])
255- def test_compute_hypothesis_test_mismatched_shapes (summary_network , monkeypatch ):
279+ @pytest .mark .parametrize (
280+ "summary_network, is_true_approximator" ,
281+ [(lambda data : data + 1 , True ), (None , True ), (lambda data : data + 1 , False )],
282+ )
283+ def test_compute_hypothesis_test_mismatched_shapes (summary_network , is_true_approximator , monkeypatch ):
256284 """Test that compute_hypothesis_test raises ValueError for mismatched shapes."""
257285 observed_data = np .random .rand (10 , 5 )
258286 reference_data = np .random .rand (20 , 4 )
259287 num_null_samples = 10
260288
261- mock_approximator = bf .approximators .ContinuousApproximator (
262- adapter = None ,
263- inference_network = None ,
264- summary_network = None ,
265- )
266-
267- monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
289+ if is_true_approximator :
290+ mock_approximator = bf .approximators .ContinuousApproximator (
291+ adapter = None ,
292+ inference_network = None ,
293+ summary_network = None ,
294+ )
295+ monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
296+ else :
297+ mock_approximator = bf .networks .SummaryNetwork ()
298+ monkeypatch .setattr (mock_approximator , "call" , summary_network )
268299
269300 with pytest .raises (ValueError ):
270301 bf .diagnostics .metrics .compute_mmd_hypothesis_test (
0 commit comments