Skip to content

Commit 025c8bd

Browse files
committed
- raise exception for num_null_samples zero or negative
- add appropriate test cases
1 parent d0667ac commit 025c8bd

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

bayesflow/diagnostics/metrics/mmd_hypothesis_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def compute_mmd_hypothesis_test_from_summaries(
7878
ValueError
7979
- If number of reference summaries is less than number of observed summaries.
8080
- If the shapes of observed and reference summaries do not match on dimensions besides the first one.
81+
- If number of null samples is less than or equal to 0.
8182
"""
8283
num_observed: int = observed_summaries.shape[0]
8384
num_reference: int = reference_summaries.shape[0]
@@ -94,6 +95,9 @@ def compute_mmd_hypothesis_test_from_summaries(
9495
f"but got {observed_summaries.shape[1:]} != {reference_summaries.shape[1:]}."
9596
)
9697

98+
if num_null_samples <= 0:
99+
raise ValueError(f"Number of null samples must be greater than 0, but got {num_null_samples}.")
100+
97101
observed_summaries_tensor: Tensor = convert_to_tensor(observed_summaries, dtype="float32")
98102
reference_summaries_tensor: Tensor = convert_to_tensor(reference_summaries, dtype="float32")
99103

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,9 @@ def test_compute_hypothesis_test_from_summaries_mismatched_shapes():
147147
)
148148

149149

150-
def test_compute_hypothesis_test_from_summaries_observed_larger_than_reference_summaries():
151-
"""Test that compute_hypothesis_test_from_summaries raises ValueError if observed is larger than reference."""
150+
def test_compute_hypothesis_test_from_summaries_reference_smaller_than_observed_summaries():
151+
"""Test that compute_hypothesis_test_from_summaries raises ValueError if number of reference summaries smaller than
152+
observed."""
152153
observed_summaries = np.random.rand(20, 5)
153154
reference_summaries = np.random.rand(10, 5)
154155
num_null_samples = 10
@@ -159,6 +160,30 @@ def test_compute_hypothesis_test_from_summaries_observed_larger_than_reference_s
159160
)
160161

161162

163+
def test_compute_hypothesis_test_from_summaries_num_null_samples_zero():
164+
"""Test that compute_hypothesis_test_from_summaries raises ValueError if num_null_samples is zero."""
165+
observed_summaries = np.random.rand(20, 5)
166+
reference_summaries = np.random.rand(10, 5)
167+
num_null_samples = 0
168+
169+
with pytest.raises(ValueError):
170+
bf.diagnostics.metrics.compute_mmd_hypothesis_test_from_summaries(
171+
observed_summaries, reference_summaries, num_null_samples
172+
)
173+
174+
175+
def test_compute_hypothesis_test_from_summaries_num_null_samples_negative():
176+
"""Test that compute_hypothesis_test_from_summaries raises ValueError if num_null_samples is negative."""
177+
observed_summaries = np.random.rand(20, 5)
178+
reference_summaries = np.random.rand(10, 5)
179+
num_null_samples = -1
180+
181+
with pytest.raises(ValueError):
182+
bf.diagnostics.metrics.compute_mmd_hypothesis_test_from_summaries(
183+
observed_summaries, reference_summaries, num_null_samples
184+
)
185+
186+
162187
@pytest.mark.parametrize(
163188
"summary_network, is_true_approximator",
164189
[(lambda data: data + 1, True), (None, True), (lambda data: data + 1, False)],

0 commit comments

Comments
 (0)