@@ -97,26 +97,23 @@ def test_compute_hypothesis_test_from_summaries_shapes():
9797 assert mmd_null .shape == (num_null_samples ,)
9898
9999
100- def test_compute_hypothesis_test_shapes_with_mock (monkeypatch ):
101- """Test the compute_mmd_hypothesis_test output shapes using pytest monkeypatch."""
100+ @pytest .mark .parametrize ("summary_network" , [lambda data : np .random .rand (data .shape [0 ], 5 ), None ])
101+ def test_compute_hypothesis_test_shapes (summary_network , monkeypatch ):
102+ """Test the compute_mmd_hypothesis_test output shapes."""
102103 # Mock observed and reference data
103104 observed_data = np .random .rand (10 , 5 )
104105 reference_data = np .random .rand (100 , 5 )
105106 num_null_samples = 50
106107
107- # Mock the summary_network method
108- def mock_summary_network (data ):
109- return np .random .rand (data .shape [0 ], 5 )
110-
111108 # Create a dummy ContinuousApproximator instance
112109 mock_approximator = bf .approximators .ContinuousApproximator (
113- adapter = None , # Pass None or a mock Adapter if required
114- inference_network = None , # Pass None or a mock InferenceNetwork if required
115- summary_network = None , # This will be replaced by the monkeypatched method
110+ adapter = None ,
111+ inference_network = None ,
112+ summary_network = None ,
116113 )
117114
118115 # Patch the summary_network attribute of the mock_approximator instance
119- monkeypatch .setattr (mock_approximator , "summary_network" , mock_summary_network )
116+ monkeypatch .setattr (mock_approximator , "summary_network" , summary_network )
120117
121118 # Call the function under test
122119 mmd_observed , mmd_null = bf .diagnostics .metrics .compute_mmd_hypothesis_test (
0 commit comments