Skip to content

Commit dd725b7

Browse files
committed
add test cases for when approximator argument is bayesflow.networks.SummaryNetwork
1 parent 946e405 commit dd725b7

File tree

2 files changed

+76
-45
lines changed

2 files changed

+76
-45
lines changed

bayesflow/diagnostics/metrics/mmd_hypothesis_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def compute_mmd_hypothesis_test(
157157
reference_data : np.ndarray
158158
Reference data, shape (num_reference, ...).
159159
approximator : ContinuousApproximator or SummaryNetwork
160-
An instance of the ContinuousApproximator or SummaryNetwork class used to extract summary statistics from data.
160+
An instance of the ContinuousApproximator or SummaryNetwork class use to extract summary statistics from data.
161161
num_null_samples : int
162162
Number of null samples to generate for hypothesis testing. Default is 100.
163163

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 75 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -161,20 +161,26 @@ def test_compute_hypothesis_test_from_summaries_num_null_samples_exceeds_referen
161161
)
162162

163163

164-
@pytest.mark.parametrize("summary_network", [lambda data: data + 1, None])
165-
def test_compute_hypothesis_test_shapes(summary_network, monkeypatch):
164+
@pytest.mark.parametrize(
165+
"summary_network, is_true_approximator",
166+
[(lambda data: data + 1, True), (None, True), (lambda data: data + 1, False)],
167+
)
168+
def test_compute_hypothesis_test_shapes(summary_network, is_true_approximator, monkeypatch):
166169
"""Test the compute_mmd_hypothesis_test output shapes."""
167170
observed_data = np.random.rand(10, 5)
168171
reference_data = np.random.rand(100, 5)
169172
num_null_samples = 50
170173

171-
mock_approximator = bf.approximators.ContinuousApproximator(
172-
adapter=None,
173-
inference_network=None,
174-
summary_network=None,
175-
)
176-
177-
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
174+
if is_true_approximator:
175+
mock_approximator = bf.approximators.ContinuousApproximator(
176+
adapter=None,
177+
inference_network=None,
178+
summary_network=None,
179+
)
180+
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
181+
else:
182+
mock_approximator = bf.networks.SummaryNetwork()
183+
monkeypatch.setattr(mock_approximator, "call", summary_network)
178184

179185
mmd_observed, mmd_null = bf.diagnostics.metrics.compute_mmd_hypothesis_test(
180186
observed_data, reference_data, mock_approximator, num_null_samples=num_null_samples
@@ -185,20 +191,26 @@ def test_compute_hypothesis_test_shapes(summary_network, monkeypatch):
185191
assert mmd_null.shape == (num_null_samples,)
186192

187193

188-
@pytest.mark.parametrize("summary_network", [lambda data: data + 1, None])
189-
def test_compute_hypothesis_test_positive(summary_network, monkeypatch):
194+
@pytest.mark.parametrize(
195+
"summary_network, is_true_approximator",
196+
[(lambda data: data + 1, True), (None, True), (lambda data: data + 1, False)],
197+
)
198+
def test_compute_hypothesis_test_positive(summary_network, is_true_approximator, monkeypatch):
190199
"""Test MMD output values of compute_hypothesis_test are positive."""
191200
observed_data = np.random.rand(10, 5)
192201
reference_data = np.random.rand(100, 5)
193202
num_null_samples = 50
194203

195-
mock_approximator = bf.approximators.ContinuousApproximator(
196-
adapter=None,
197-
inference_network=None,
198-
summary_network=None,
199-
)
200-
201-
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
204+
if is_true_approximator:
205+
mock_approximator = bf.approximators.ContinuousApproximator(
206+
adapter=None,
207+
inference_network=None,
208+
summary_network=None,
209+
)
210+
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
211+
else:
212+
mock_approximator = bf.networks.SummaryNetwork()
213+
monkeypatch.setattr(mock_approximator, "call", summary_network)
202214

203215
mmd_observed, mmd_null = bf.diagnostics.metrics.compute_mmd_hypothesis_test(
204216
observed_data, reference_data, mock_approximator, num_null_samples=num_null_samples
@@ -208,20 +220,26 @@ def test_compute_hypothesis_test_positive(summary_network, monkeypatch):
208220
assert np.all(mmd_null >= 0)
209221

210222

211-
@pytest.mark.parametrize("summary_network", [lambda data: data + 1, None])
212-
def test_compute_hypothesis_test_same_distribution(summary_network, monkeypatch):
223+
@pytest.mark.parametrize(
224+
"summary_network, is_true_approximator",
225+
[(lambda data: data + 1, True), (None, True), (lambda data: data + 1, False)],
226+
)
227+
def test_compute_hypothesis_test_same_distribution(summary_network, is_true_approximator, monkeypatch):
213228
"""Test compute_hypothesis_test on same distributions."""
214229
observed_data = np.random.rand(10, 5)
215230
reference_data = observed_data.copy()
216231
num_null_samples = 5
217232

218-
mock_approximator = bf.approximators.ContinuousApproximator(
219-
adapter=None,
220-
inference_network=None,
221-
summary_network=None,
222-
)
223-
224-
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
233+
if is_true_approximator:
234+
mock_approximator = bf.approximators.ContinuousApproximator(
235+
adapter=None,
236+
inference_network=None,
237+
summary_network=None,
238+
)
239+
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
240+
else:
241+
mock_approximator = bf.networks.SummaryNetwork()
242+
monkeypatch.setattr(mock_approximator, "call", summary_network)
225243

226244
mmd_observed, mmd_null = bf.diagnostics.metrics.compute_mmd_hypothesis_test(
227245
observed_data, reference_data, mock_approximator, num_null_samples=num_null_samples
@@ -230,19 +248,26 @@ def test_compute_hypothesis_test_same_distribution(summary_network, monkeypatch)
230248
assert mmd_observed <= np.quantile(mmd_null, 0.99)
231249

232250

233-
@pytest.mark.parametrize("summary_network", [lambda data: data + 1, None])
234-
def test_compute_hypothesis_test_different_distributions(summary_network, monkeypatch):
251+
@pytest.mark.parametrize(
252+
"summary_network, is_true_approximator",
253+
[(lambda data: data + 1, True), (None, True), (lambda data: data + 1, False)],
254+
)
255+
def test_compute_hypothesis_test_different_distributions(summary_network, is_true_approximator, monkeypatch):
235256
"""Test compute_hypothesis_test on different distributions."""
236257
observed_data = np.random.rand(10, 5)
237258
reference_data = np.random.normal(loc=0.5, scale=0.1, size=(100, 5))
238259
num_null_samples = 50
239260

240-
mock_approximator = bf.approximators.ContinuousApproximator(
241-
adapter=None,
242-
inference_network=None,
243-
summary_network=None,
244-
)
245-
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
261+
if is_true_approximator:
262+
mock_approximator = bf.approximators.ContinuousApproximator(
263+
adapter=None,
264+
inference_network=None,
265+
summary_network=None,
266+
)
267+
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
268+
else:
269+
mock_approximator = bf.networks.SummaryNetwork()
270+
monkeypatch.setattr(mock_approximator, "call", summary_network)
246271

247272
mmd_observed, mmd_null = bf.diagnostics.metrics.compute_mmd_hypothesis_test(
248273
observed_data, reference_data, mock_approximator, num_null_samples=num_null_samples
@@ -251,20 +276,26 @@ def test_compute_hypothesis_test_different_distributions(summary_network, monkey
251276
assert mmd_observed >= np.quantile(mmd_null, 0.68)
252277

253278

254-
@pytest.mark.parametrize("summary_network", [lambda data: data + 1, None])
255-
def test_compute_hypothesis_test_mismatched_shapes(summary_network, monkeypatch):
279+
@pytest.mark.parametrize(
280+
"summary_network, is_true_approximator",
281+
[(lambda data: data + 1, True), (None, True), (lambda data: data + 1, False)],
282+
)
283+
def test_compute_hypothesis_test_mismatched_shapes(summary_network, is_true_approximator, monkeypatch):
256284
"""Test that compute_hypothesis_test raises ValueError for mismatched shapes."""
257285
observed_data = np.random.rand(10, 5)
258286
reference_data = np.random.rand(20, 4)
259287
num_null_samples = 10
260288

261-
mock_approximator = bf.approximators.ContinuousApproximator(
262-
adapter=None,
263-
inference_network=None,
264-
summary_network=None,
265-
)
266-
267-
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
289+
if is_true_approximator:
290+
mock_approximator = bf.approximators.ContinuousApproximator(
291+
adapter=None,
292+
inference_network=None,
293+
summary_network=None,
294+
)
295+
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
296+
else:
297+
mock_approximator = bf.networks.SummaryNetwork()
298+
monkeypatch.setattr(mock_approximator, "call", summary_network)
268299

269300
with pytest.raises(ValueError):
270301
bf.diagnostics.metrics.compute_mmd_hypothesis_test(

0 commit comments

Comments
 (0)