diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index a1aa85c7e..450b46e30 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -3,7 +3,7 @@ channels: - conda-forge - nodefaults dependencies: -- pymc>=5.20 +- pymc>=5.21 - pytest-cov>=2.5 - pytest>=3.0 - dask diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 1d1eb7745..d2a3e8934 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -10,7 +10,7 @@ dependencies: - xhistogram - statsmodels - numba<=0.60.0 -- pymc>=5.20 +- pymc>=5.21 - pip: - blackjax - scikit-learn diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 2590dd53a..d5d131d80 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -707,7 +707,7 @@ def _insert_random_variables(self): with pymc_model: for param_name in self.param_names: param = getattr(pymc_model, param_name, None) - if param: + if param is not None: found_params.append(param.name) missing_params = list(set(self.param_names) - set(found_params)) @@ -746,7 +746,7 @@ def _insert_data_variables(self): with pymc_model: for data_name in data_names: data = getattr(pymc_model, data_name, None) - if data: + if data is not None: found_data.append(data.name) missing_data = list(set(data_names) - set(found_data)) diff --git a/pymc_extras/statespace/filters/distributions.py b/pymc_extras/statespace/filters/distributions.py index d3b70c847..3d0ed44d6 100644 --- a/pymc_extras/statespace/filters/distributions.py +++ b/pymc_extras/statespace/filters/distributions.py @@ -62,29 +62,8 @@ class MvNormalSVD(MvNormal): A JAX MvNormal robust to low-rank covariance matrices """ - rv_op = MvNormalSVDRV() - - -try: - import jax.random - - from pytensor.link.jax.dispatch.random import jax_sample_fn - - @jax_sample_fn.register(MvNormalSVDRV) - def jax_sample_fn_mvnormal_svd(op, node): - def sample_fn(rng, size, dtype, *parameters): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) - sample = jax.random.multivariate_normal( - sampling_key, *parameters, shape=size, dtype=dtype, method="svd" - ) - rng["jax_state"] = rng_key - return (rng, sample) - - return sample_fn - -except ImportError: - pass + # TODO: Remove this entirely on next PyMC release; method will be exposed directly in MvNormal + rv_op = MvNormalSVDRV(method="svd") class LinearGaussianStateSpaceRV(SymbolicRandomVariable): diff --git a/pyproject.toml b/pyproject.toml index 17187a524..e16b3b389 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,9 @@ filterwarnings =[ # Warning coming from blackjax 'ignore:jax\.tree_map is deprecated:DeprecationWarning', + + # Ignore PyMC use of numpy.core + 'ignore:numpy\.core\.numeric is deprecated:DeprecationWarning', ] [tool.coverage.report] diff --git a/requirements.txt b/requirements.txt index a4f00ee21..3a1f85ac8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -pymc>=5.20 +pymc>=5.21 scikit-learn better-optimize