Skip to content

Commit e163a6f

Browse files
Fix CI errors in statspace test suite (#251)
* Ignore numpy depreciation warnings from statsmodel * Squeeze result vector-matrix multiplication with (1, 1) matrix to avoid shape error in numpy 1.25.2 * Consolidate all project settings into `pyproject.toml` * Delete unused`pytest.ini` and `setup.cfg` * Remove unnecessary filtered warnings * Change pathfinder jax import from deprecated pymc.sampling_jax to pymc.sampling.jax * Skip pathfinder test if python < 3.10 * Add some comments to `pyproject.toml` to explain what warnings are being ignored and why
1 parent 3444ede commit e163a6f

File tree

7 files changed

+16
-21
lines changed

7 files changed

+16
-21
lines changed

pymc_experimental/inference/pathfinder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import numpy as np
2626
import pymc as pm
2727
from pymc import modelcontext
28-
from pymc.sampling_jax import get_jaxified_graph
28+
from pymc.sampling.jax import get_jaxified_graph
2929
from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames
3030

3131

pymc_experimental/tests/statespace/test_VARMAX.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def test_VARMAX_param_counts_match_statsmodels(data, order, var):
9696

9797
@pytest.mark.parametrize("order", orders, ids=ids)
9898
@pytest.mark.filterwarnings("ignore::statsmodels.tools.sm_exceptions.EstimationWarning")
99+
@pytest.mark.filterwarnings("ignore::FutureWarning")
99100
def test_VARMAX_update_matches_statsmodels(data, order, rng):
100101
p, q = order
101102

pymc_experimental/tests/statespace/utilities/test_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def simulate_from_numpy_model(mod, rng, param_dict, steps=100):
227227
y = np.zeros(steps)
228228

229229
x[0] = x0
230-
y[0] = Z @ x0
230+
y[0] = (Z @ x0).squeeze()
231231

232232
if not np.allclose(H, 0):
233233
y[0] += rng.multivariate_normal(mean=np.zeros(1), cov=H)
@@ -245,7 +245,7 @@ def simulate_from_numpy_model(mod, rng, param_dict, steps=100):
245245
error = 0
246246

247247
x[t] = c + T @ x[t - 1] + innov
248-
y[t] = d + Z @ x[t] + error
248+
y[t] = (d + Z @ x[t] + error).squeeze()
249249

250250
return x, y
251251

pymc_experimental/tests/test_pathfinder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323

2424
# TODO: Remove this filterwarning after pytensor uses jnp.prod instead of jnp.product
2525
@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.")
26+
@pytest.mark.skipif(
27+
sys.version_info < (3, 10), reason="pymc.sampling.jax does not currently support python < 3.10"
28+
)
2629
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
2730
def test_pathfinder():
2831
# Data of the Eight Schools Model

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@ addopts = [
77
"--ignore=pymc_experimental/model_builder.py"
88
]
99

10+
filterwarnings =[
11+
"error",
12+
# Raised by arviz when the model_builder class adds non-standard group names to InferenceData
13+
"ignore::UserWarning:arviz.data.inference_data",
14+
15+
# bool8, find_common_type, cumproduct, and product had deprecation warnings added in numpy 1.25
16+
'ignore:.*(\b(pkg_resources\.declare_namespace|np\.bool8|np\.find_common_type|cumproduct|product)\b).*:DeprecationWarning',
17+
]
1018

1119
[tool.black]
1220
line-length = 100
@@ -20,6 +28,7 @@ exclude_lines = [
2028

2129
[tool.isort]
2230
profile = "black"
31+
# lines_between_types = 1
2332

2433
[tool.nbqa.mutate]
2534
isort = 1

pytest.ini

Lines changed: 0 additions & 7 deletions
This file was deleted.

setup.cfg

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)