Skip to content

Commit 4b92331

Browse files
Merge branch 'pymc-devs:main' into implement-minimiser-for-INLA
2 parents 35e525e + dcad7f2 commit 4b92331

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

conda-envs/environment-test.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ dependencies:
1212
- numba
1313
- pytest
1414
- pytest-cov
15-
- libgcc<15
1615
- pydantic>=2.0.0
1716
- preliz
1817
- pip

pymc_extras/linearmodel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import pandas as pd
33
import pymc as pm
44

5+
from sklearn.base import BaseEstimator
6+
57
from pymc_extras.model_builder import ModelBuilder
68

79

8-
class LinearModel(ModelBuilder):
10+
class LinearModel(ModelBuilder, BaseEstimator):
911
def __init__(
1012
self, model_config: dict | None = None, sampler_config: dict | None = None, nsamples=100
1113
):

tests/statespace/core/test_statespace_JAX.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
from pymc.model.transform.optimization import freeze_dims_and_data
10+
from pymc.testing import mock_sample_setup_and_teardown
1011

1112
from pymc_extras.statespace.utils.constants import (
1213
FILTER_OUTPUT_NAMES,
@@ -30,6 +31,8 @@
3031
nile = load_nile_test_data()
3132
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
3233

34+
mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown)
35+
3336

3437
@pytest.fixture(scope="session")
3538
def pymc_mod(ss_mod):
@@ -66,7 +69,7 @@ def exog_pymc_mod(exog_ss_mod, rng):
6669

6770

6871
@pytest.fixture(scope="session")
69-
def idata(pymc_mod, rng):
72+
def idata(pymc_mod, rng, mock_pymc_sample):
7073
with warnings.catch_warnings():
7174
warnings.simplefilter("ignore")
7275
with pymc_mod:
@@ -88,7 +91,7 @@ def idata(pymc_mod, rng):
8891

8992

9093
@pytest.fixture(scope="session")
91-
def idata_exog(exog_pymc_mod, rng):
94+
def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
9295
with warnings.catch_warnings():
9396
warnings.simplefilter("ignore")
9497

0 commit comments

Comments
 (0)