Skip to content

Commit f5687e8

Browse files
committed
add unit test case for when ContinuousApproximator.summary_network = None
1 parent ef9dd00 commit f5687e8

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

bayesflow/diagnostics/metrics/mmd_hypothesis_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ def compute_mmd_hypothesis_test(
154154
observed_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(observed_data_tensor))
155155
reference_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(reference_data_tensor))
156156
else:
157-
observed_summaries: np.ndarray = convert_to_numpy(observed_data_tensor)
158-
reference_summaries: np.ndarray = convert_to_numpy(reference_data_tensor)
157+
observed_summaries: np.ndarray = observed_data
158+
reference_summaries: np.ndarray = reference_data
159159

160160
mmd_observed, mmd_null = compute_mmd_hypothesis_test_from_summaries(
161161
observed_summaries, reference_summaries, num_null_samples=num_null_samples

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,26 +97,23 @@ def test_compute_hypothesis_test_from_summaries_shapes():
9797
assert mmd_null.shape == (num_null_samples,)
9898

9999

100-
def test_compute_hypothesis_test_shapes_with_mock(monkeypatch):
101-
"""Test the compute_mmd_hypothesis_test output shapes using pytest monkeypatch."""
100+
@pytest.mark.parametrize("summary_network", [lambda data: np.random.rand(data.shape[0], 5), None])
101+
def test_compute_hypothesis_test_shapes(summary_network, monkeypatch):
102+
"""Test the compute_mmd_hypothesis_test output shapes."""
102103
# Mock observed and reference data
103104
observed_data = np.random.rand(10, 5)
104105
reference_data = np.random.rand(100, 5)
105106
num_null_samples = 50
106107

107-
# Mock the summary_network method
108-
def mock_summary_network(data):
109-
return np.random.rand(data.shape[0], 5)
110-
111108
# Create a dummy ContinuousApproximator instance
112109
mock_approximator = bf.approximators.ContinuousApproximator(
113-
adapter=None, # Pass None or a mock Adapter if required
114-
inference_network=None, # Pass None or a mock InferenceNetwork if required
115-
summary_network=None, # This will be replaced by the monkeypatched method
110+
adapter=None,
111+
inference_network=None,
112+
summary_network=None,
116113
)
117114

118115
# Patch the summary_network attribute of the mock_approximator instance
119-
monkeypatch.setattr(mock_approximator, "summary_network", mock_summary_network)
116+
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
120117

121118
# Call the function under test
122119
mmd_observed, mmd_null = bf.diagnostics.metrics.compute_mmd_hypothesis_test(

0 commit comments

Comments
 (0)