Skip to content

Commit 6fa1925

Browse files
committed
add test cases for indirect Raises through compute_hypothesis_test
1 parent ab0b895 commit 6fa1925

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,45 @@ def test_compute_hypothesis_test_different_distributions(summary_network, monkey
247247
)
248248

249249
assert mmd_observed >= np.quantile(mmd_null, 0.68)
250+
251+
252+
@pytest.mark.parametrize("summary_network", [lambda data: np.random.rand(data.shape[0], 5), None])
253+
def test_compute_hypothesis_test_mismatched_last_dimension(summary_network, monkeypatch):
254+
"""Test that a ValueError is raised for mismatched last dimensions."""
255+
observed_data = np.random.rand(10, 5)
256+
reference_data = np.random.rand(20, 4)
257+
num_null_samples = 10
258+
259+
mock_approximator = bf.approximators.ContinuousApproximator(
260+
adapter=None,
261+
inference_network=None,
262+
summary_network=None,
263+
)
264+
265+
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
266+
267+
with pytest.raises(ValueError):
268+
bf.diagnostics.metrics.compute_mmd_hypothesis_test(
269+
observed_data, reference_data, mock_approximator, num_null_samples
270+
)
271+
272+
273+
@pytest.mark.parametrize("summary_network", [lambda data: np.random.rand(data.shape[0], 5), None])
274+
def test_compute_hypothesis_test_num_null_samples_exceeds_reference_samples(summary_network, monkeypatch):
275+
"""Test that a ValueError is raised when num_null_samples exceeds the number of reference samples."""
276+
observed_data = np.random.rand(10, 5)
277+
reference_data = np.random.rand(5, 5)
278+
num_null_samples = 10
279+
280+
mock_approximator = bf.approximators.ContinuousApproximator(
281+
adapter=None,
282+
inference_network=None,
283+
summary_network=None,
284+
)
285+
286+
monkeypatch.setattr(mock_approximator, "summary_network", summary_network)
287+
288+
with pytest.raises(ValueError):
289+
bf.diagnostics.metrics.compute_mmd_hypothesis_test(
290+
observed_data, reference_data, mock_approximator, num_null_samples
291+
)

0 commit comments

Comments
 (0)