diff --git a/tests/statespace/core/test_statespace_JAX.py b/tests/statespace/core/test_statespace_JAX.py index c24504af8..abfafae55 100644 --- a/tests/statespace/core/test_statespace_JAX.py +++ b/tests/statespace/core/test_statespace_JAX.py @@ -7,6 +7,7 @@ import pytest from pymc.model.transform.optimization import freeze_dims_and_data +from pymc.testing import mock_sample_setup_and_teardown from pymc_extras.statespace.utils.constants import ( FILTER_OUTPUT_NAMES, @@ -30,6 +31,8 @@ nile = load_nile_test_data() ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES +mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown) + @pytest.fixture(scope="session") def pymc_mod(ss_mod): @@ -66,7 +69,7 @@ def exog_pymc_mod(exog_ss_mod, rng): @pytest.fixture(scope="session") -def idata(pymc_mod, rng): +def idata(pymc_mod, rng, mock_pymc_sample): with warnings.catch_warnings(): warnings.simplefilter("ignore") with pymc_mod: @@ -88,7 +91,7 @@ def idata(pymc_mod, rng): @pytest.fixture(scope="session") -def idata_exog(exog_pymc_mod, rng): +def idata_exog(exog_pymc_mod, rng, mock_pymc_sample): with warnings.catch_warnings(): warnings.simplefilter("ignore")