Skip to content

Commit f67da8a

Browse files
committed
adjust mock summary_network in unit tests to be a deterministic transform
1 parent b29a365 commit f67da8a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def test_compute_hypothesis_test_from_summaries_num_null_samples_exceeds_referen
161161
)
162162

163163

164-
@pytest.mark.parametrize("summary_network", [lambda data: np.random.rand(data.shape[0], 5), None])
164+
@pytest.mark.parametrize("summary_network", [lambda data: data + 1, None])
165165
def test_compute_hypothesis_test_shapes(summary_network, monkeypatch):
166166
"""Test the compute_mmd_hypothesis_test output shapes."""
167167
observed_data = np.random.rand(10, 5)
@@ -185,7 +185,7 @@ def test_compute_hypothesis_test_shapes(summary_network, monkeypatch):
185185
assert mmd_null.shape == (num_null_samples,)
186186

187187

188-
@pytest.mark.parametrize("summary_network", [lambda data: np.random.rand(data.shape[0], 5), None])
188+
@pytest.mark.parametrize("summary_network", [lambda data: data + 1, None])
189189
def test_compute_hypothesis_test_positive(summary_network, monkeypatch):
190190
"""Test MMD output values of compute_hypothesis_test are positive."""
191191
observed_data = np.random.rand(10, 5)
@@ -251,7 +251,7 @@ def test_compute_hypothesis_test_different_distributions(summary_network, monkey
251251
assert mmd_observed >= np.quantile(mmd_null, 0.68)
252252

253253

254-
@pytest.mark.parametrize("summary_network", [lambda data: np.random.rand(data.shape[0], 5)])
254+
@pytest.mark.parametrize("summary_network", [lambda data: data + 1, None])
255255
def test_compute_hypothesis_test_mismatched_shapes(summary_network, monkeypatch):
256256
"""Test that compute_hypothesis_test raises ValueError for mismatched shapes."""
257257
observed_data = np.random.rand(10, 5)

0 commit comments

Comments
 (0)