Skip to content

Commit cc4f247

Browse files
Use mock sampling throughout statespace tests (#518)
1 parent c35bc67 commit cc4f247

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tests/statespace/core/test_statespace_JAX.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
from pymc.model.transform.optimization import freeze_dims_and_data
10+
from pymc.testing import mock_sample_setup_and_teardown
1011

1112
from pymc_extras.statespace.utils.constants import (
1213
FILTER_OUTPUT_NAMES,
@@ -30,6 +31,8 @@
3031
nile = load_nile_test_data()
3132
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
3233

34+
mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown)
35+
3336

3437
@pytest.fixture(scope="session")
3538
def pymc_mod(ss_mod):
@@ -66,7 +69,7 @@ def exog_pymc_mod(exog_ss_mod, rng):
6669

6770

6871
@pytest.fixture(scope="session")
69-
def idata(pymc_mod, rng):
72+
def idata(pymc_mod, rng, mock_pymc_sample):
7073
with warnings.catch_warnings():
7174
warnings.simplefilter("ignore")
7275
with pymc_mod:
@@ -88,7 +91,7 @@ def idata(pymc_mod, rng):
8891

8992

9093
@pytest.fixture(scope="session")
91-
def idata_exog(exog_pymc_mod, rng):
94+
def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
9295
with warnings.catch_warnings():
9396
warnings.simplefilter("ignore")
9497

0 commit comments

Comments
 (0)