Skip to content

Commit 252f70a

Browse files
committed
added mock_sample_setup_and_teardown to statespace tests
1 parent 4119a0d commit 252f70a

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/statespace/core/test_statespace.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010

1111
from numpy.testing import assert_allclose
12+
from pymc.testing import mock_sample_setup_and_teardown
1213
from pytensor.compile import SharedVariable
1314
from pytensor.graph.basic import graph_inputs
1415

@@ -32,6 +33,7 @@
3233
floatX = pytensor.config.floatX
3334
nile = load_nile_test_data()
3435
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
36+
mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown)
3537

3638

3739
def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None):
@@ -214,7 +216,7 @@ def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):
214216

215217

216218
@pytest.fixture(scope="session")
217-
def idata(pymc_mod, rng):
219+
def idata(pymc_mod, rng, mock_pymc_sample):
218220
with pymc_mod:
219221
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
220222
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
@@ -224,7 +226,7 @@ def idata(pymc_mod, rng):
224226

225227

226228
@pytest.fixture(scope="session")
227-
def idata_exog(exog_pymc_mod, rng):
229+
def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
228230
with exog_pymc_mod:
229231
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
230232
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
@@ -233,7 +235,7 @@ def idata_exog(exog_pymc_mod, rng):
233235

234236

235237
@pytest.fixture(scope="session")
236-
def idata_no_exog(pymc_mod_no_exog, rng):
238+
def idata_no_exog(pymc_mod_no_exog, rng, mock_pymc_sample):
237239
with pymc_mod_no_exog:
238240
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
239241
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
@@ -242,7 +244,7 @@ def idata_no_exog(pymc_mod_no_exog, rng):
242244

243245

244246
@pytest.fixture(scope="session")
245-
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng):
247+
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng, mock_pymc_sample):
246248
with pymc_mod_no_exog_dt:
247249
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
248250
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)

0 commit comments

Comments
 (0)