Skip to content

Commit 3c6c33c

Browse files
committed
add Raises to compute_hypothesis_test for unmatching observed and reference data + add corresponding test cases
1 parent 6fa1925 commit 3c6c33c

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

bayesflow/diagnostics/metrics/mmd_hypothesis_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,18 @@ def compute_mmd_hypothesis_test(
166166
The MMD value between observed and reference data.
167167
mmd_null : np.ndarray
168168
A distribution of MMD values under the null hypothesis.
169+
170+
Raises:
171+
------
172+
ValueError
173+
- If the shapes of observed and reference data do not match on dimensions besides the first one.
169174
"""
175+
if observed_data.shape[1:] != reference_data.shape[1:]:
176+
raise ValueError(
177+
f"Expected observed and reference data to have the same shape, "
178+
f"but got {observed_data.shape[1:]} != {reference_data.shape[1:]}."
179+
)
180+
170181
if approximator.summary_network is not None:
171182
observed_data_tensor: Tensor = convert_to_tensor(observed_data)
172183
reference_data_tensor: Tensor = convert_to_tensor(reference_data)

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def test_compute_hypothesis_test_different_distributions(summary_network, monkey
249249
assert mmd_observed >= np.quantile(mmd_null, 0.68)
250250

251251

252-
@pytest.mark.parametrize("summary_network", [lambda data: np.random.rand(data.shape[0], 5), None])
252+
@pytest.mark.parametrize("summary_network", [lambda data: np.random.rand(data.shape[0], 5)])
253253
def test_compute_hypothesis_test_mismatched_last_dimension(summary_network, monkeypatch):
254254
"""Test that a ValueError is raised for mismatched last dimensions."""
255255
observed_data = np.random.rand(10, 5)

0 commit comments

Comments
 (0)