@@ -247,3 +247,45 @@ def test_compute_hypothesis_test_different_distributions(summary_network, monkey
247247 )
248248
249249 assert mmd_observed >= np .quantile (mmd_null , 0.68 )
250+
251+
252+ @pytest .mark .parametrize ("summary_network" , [lambda data : np .random .rand (data .shape [0 ], 5 ), None ])
253+ def test_compute_hypothesis_test_mismatched_last_dimension (summary_network , monkeypatch ):
254+ """Test that a ValueError is raised for mismatched last dimensions."""
255+ observed_data = np .random .rand (10 , 5 )
256+ reference_data = np .random .rand (20 , 4 )
257+ num_null_samples = 10
258+
259+ mock_approximator = bf .approximators .ContinuousApproximator (
260+ adapter = None ,
261+ inference_network = None ,
262+ summary_network = None ,
263+ )
264+
265+ monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
266+
267+ with pytest .raises (ValueError ):
268+ bf .diagnostics .metrics .compute_mmd_hypothesis_test (
269+ observed_data , reference_data , mock_approximator , num_null_samples
270+ )
271+
272+
273+ @pytest .mark .parametrize ("summary_network" , [lambda data : np .random .rand (data .shape [0 ], 5 ), None ])
274+ def test_compute_hypothesis_test_num_null_samples_exceeds_reference_samples (summary_network , monkeypatch ):
275+ """Test that a ValueError is raised when num_null_samples exceeds the number of reference samples."""
276+ observed_data = np .random .rand (10 , 5 )
277+ reference_data = np .random .rand (5 , 5 )
278+ num_null_samples = 10
279+
280+ mock_approximator = bf .approximators .ContinuousApproximator (
281+ adapter = None ,
282+ inference_network = None ,
283+ summary_network = None ,
284+ )
285+
286+ monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
287+
288+ with pytest .raises (ValueError ):
289+ bf .diagnostics .metrics .compute_mmd_hypothesis_test (
290+ observed_data , reference_data , mock_approximator , num_null_samples
291+ )
0 commit comments