Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions tests/statespace/core/test_statespace_JAX.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.testing import mock_sample_setup_and_teardown

from pymc_extras.statespace.utils.constants import (
FILTER_OUTPUT_NAMES,
Expand All @@ -30,6 +31,8 @@
nile = load_nile_test_data()
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES

mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown)


@pytest.fixture(scope="session")
def pymc_mod(ss_mod):
Expand Down Expand Up @@ -66,7 +69,7 @@ def exog_pymc_mod(exog_ss_mod, rng):


@pytest.fixture(scope="session")
def idata(pymc_mod, rng):
def idata(pymc_mod, rng, mock_pymc_sample):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with pymc_mod:
Expand All @@ -88,7 +91,7 @@ def idata(pymc_mod, rng):


@pytest.fixture(scope="session")
def idata_exog(exog_pymc_mod, rng):
def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
with warnings.catch_warnings():
warnings.simplefilter("ignore")

Expand Down