Skip to content

Commit d0667ac

Browse files
committed
correct bug: exception should not be raised for num_null_samples > num_reference -> adjust to num_observed > num_reference
1 parent cd3e0f6 commit d0667ac

File tree

2 files changed

+8
-11
lines changed

2 files changed

+8
-11
lines changed

bayesflow/diagnostics/metrics/mmd_hypothesis_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,16 @@ def compute_mmd_hypothesis_test_from_summaries(
7676
Raises
7777
------
7878
ValueError
79-
- If the number of null samples exceeds the number of reference samples or if the shapes of observed and
80-
reference summaries do not match.
79+
- If number of reference summaries is less than number of observed summaries.
8180
- If the shapes of observed and reference summaries do not match on dimensions besides the first one.
8281
"""
8382
num_observed: int = observed_summaries.shape[0]
8483
num_reference: int = reference_summaries.shape[0]
8584

86-
if num_null_samples > num_reference:
85+
if num_observed > num_reference:
8786
raise ValueError(
88-
f"Number of null samples ({num_null_samples}) cannot exceed"
89-
f"the number of reference samples ({num_reference})."
87+
f"Number of reference summaries ({num_reference}) must be greater than"
88+
f"number of observed summaries ({num_observed})."
9089
)
9190

9291
if observed_summaries.shape[1:] != reference_summaries.shape[1:]:

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,10 @@ def test_compute_hypothesis_test_from_summaries_mismatched_shapes():
147147
)
148148

149149

150-
def test_compute_hypothesis_test_from_summaries_num_null_samples_exceeds_reference_samples():
151-
"""Test that compute_hypothesis_test_from_summaries raises ValueError when num_null_samples exceeds the number of
152-
reference samples.
153-
"""
154-
observed_summaries = np.random.rand(10, 5)
155-
reference_summaries = np.random.rand(5, 5)
150+
def test_compute_hypothesis_test_from_summaries_observed_larger_than_reference_summaries():
151+
"""Test that compute_hypothesis_test_from_summaries raises ValueError if observed is larger than reference."""
152+
observed_summaries = np.random.rand(20, 5)
153+
reference_summaries = np.random.rand(10, 5)
156154
num_null_samples = 10
157155

158156
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)