From fc474f7342e21cdc8191578d1acbff014eb75962 Mon Sep 17 00:00:00 2001 From: theorashid Date: Tue, 23 Jul 2024 14:20:57 +0100 Subject: [PATCH 1/2] update pre-commit-config replacing black and pylint with ruff --- .pre-commit-config.yaml | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 70fe11db9..7e6adb1c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,17 +20,12 @@ repos: hooks: - id: pyupgrade args: [--py37-plus] -- repo: https://github.com/psf/black - rev: 24.1.1 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.4 hooks: - - id: black - - id: black-jupyter -- repo: https://github.com/PyCQA/pylint - rev: v3.0.3 - hooks: - - id: pylint - args: [--rcfile=.pylintrc] - files: ^pymc_experimental/ + - id: ruff + args: ["--fix", "--output-format=full"] + - id: ruff-format - repo: https://github.com/MarcoGorelli/madforhooks rev: 0.4.1 hooks: From ec40eb77cd0aad2c89f6cbab54bbde51559eaf45 Mon Sep 17 00:00:00 2001 From: theorashid Date: Tue, 23 Jul 2024 15:06:09 +0100 Subject: [PATCH 2/2] run "pre-commit run --all-files" and then manually fix instead of --unsafe-fixes --- conftest.py | 4 +- docs/conf.py | 4 +- pymc_experimental/__init__.py | 19 +- pymc_experimental/distributions/continuous.py | 12 +- pymc_experimental/distributions/discrete.py | 16 +- .../distributions/multivariate/__init__.py | 2 + .../distributions/multivariate/r2d2m2cp.py | 27 ++- pymc_experimental/distributions/timeseries.py | 21 +- pymc_experimental/gp/__init__.py | 2 + pymc_experimental/gp/latent_approx.py | 14 +- pymc_experimental/inference/__init__.py | 2 + pymc_experimental/inference/fit.py | 3 +- pymc_experimental/inference/laplace.py | 8 +- pymc_experimental/inference/pathfinder.py | 4 +- pymc_experimental/inference/smc/sampling.py | 26 +-- pymc_experimental/linearmodel.py | 14 +- pymc_experimental/model/marginal_model.py | 103 +++++++--- .../model/transforms/autoreparam.py | 10 +- pymc_experimental/model_builder.py | 32 +++- .../statespace/core/representation.py | 20 +- .../statespace/core/statespace.py | 179 +++++++++++++----- .../statespace/filters/distributions.py | 80 ++++++-- .../statespace/filters/kalman_filter.py | 52 +++-- .../statespace/filters/kalman_smoother.py | 13 +- .../statespace/models/SARIMAX.py | 70 ++++--- pymc_experimental/statespace/models/VARMAX.py | 26 ++- .../statespace/models/structural.py | 121 +++++++++--- .../statespace/models/utilities.py | 4 +- .../statespace/utils/constants.py | 10 +- .../statespace/utils/data_tools.py | 10 +- pymc_experimental/utils/linear_cg.py | 34 +++- pymc_experimental/utils/pivoted_cholesky.py | 4 +- pymc_experimental/utils/prior.py | 6 +- pymc_experimental/utils/spline.py | 8 +- pymc_experimental/version.py | 4 +- setup.py | 9 +- setupegg.py | 2 - tests/distributions/__init__.py | 2 + tests/distributions/test_continuous.py | 13 +- tests/distributions/test_discrete.py | 10 +- .../test_discrete_markov_chain.py | 47 +++-- tests/distributions/test_multivariate.py | 33 +++- tests/model/test_marginal_model.py | 173 +++++++++++------ tests/model/transforms/test_autoreparam.py | 10 +- tests/statespace/test_SARIMAX.py | 109 +++++++---- tests/statespace/test_VARMAX.py | 62 ++++-- tests/statespace/test_coord_assignment.py | 28 ++- tests/statespace/test_distributions.py | 59 +++--- tests/statespace/test_kalman_filter.py | 32 +++- tests/statespace/test_representation.py | 18 +- tests/statespace/test_statespace.py | 77 +++++--- 
tests/statespace/test_statespace_JAX.py | 32 ++-- tests/statespace/test_structural.py | 121 ++++++++---- tests/statespace/utilities/test_helpers.py | 16 +- tests/test_blackjax_smc.py | 34 ++-- tests/test_histogram_approximation.py | 16 +- tests/test_laplace.py | 12 +- tests/test_linearmodel.py | 37 ++-- tests/test_model_builder.py | 21 +- tests/test_pathfinder.py | 4 +- tests/test_prior_from_trace.py | 13 +- tests/test_splines.py | 10 +- 62 files changed, 1376 insertions(+), 588 deletions(-) diff --git a/conftest.py b/conftest.py index 3178a8fd5..e446d0a1c 100644 --- a/conftest.py +++ b/conftest.py @@ -2,7 +2,9 @@ def pytest_addoption(parser): - parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") + parser.addoption( + "--runslow", action="store_true", default=False, help="run slow tests" + ) def pytest_configure(config): diff --git a/docs/conf.py b/docs/conf.py index 977952642..1d80b34fd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -171,7 +171,9 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [(master_doc, "pymc_experimental", "pymc_experimental Documentation", [author], 1)] +man_pages = [ + (master_doc, "pymc_experimental", "pymc_experimental Documentation", [author], 1) +] # -- Options for Texinfo output ---------------------------------------------- diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index 7b9cf7bb6..196716788 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -13,6 +13,10 @@ # limitations under the License. import logging +from pymc_experimental import distributions, gp, statespace, utils +from pymc_experimental.inference.fit import fit +from pymc_experimental.model.marginal_model import MarginalModel +from pymc_experimental.model.model_api import as_model from pymc_experimental.version import __version__ _log = logging.getLogger("pmx") @@ -23,7 +27,14 @@ handler = logging.StreamHandler() _log.addHandler(handler) -from pymc_experimental import distributions, gp, statespace, utils -from pymc_experimental.inference.fit import fit -from pymc_experimental.model.marginal_model import MarginalModel -from pymc_experimental.model.model_api import as_model + +__all__ = [ + "__version__", + "distributions", + "gp", + "statespace", + "utils", + "fit", + "MarginalModel", + "as_model", +] diff --git a/pymc_experimental/distributions/continuous.py b/pymc_experimental/distributions/continuous.py index 6c2a57002..6cd8ccfaf 100644 --- a/pymc_experimental/distributions/continuous.py +++ b/pymc_experimental/distributions/continuous.py @@ -41,7 +41,9 @@ class GenExtremeRV(RandomVariable): dtype: str = "floatX" _print_name: Tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}") - def __call__(self, mu=0.0, sigma=1.0, xi=0.0, size=None, **kwargs) -> TensorVariable: + def __call__( + self, mu=0.0, sigma=1.0, xi=0.0, size=None, **kwargs + ) -> TensorVariable: return super().__call__(mu, sigma, xi, size=size, **kwargs) @classmethod @@ -54,7 +56,9 @@ def rng_fn( size: Tuple[int, ...], ) -> np.ndarray: # Notice negative here, since remainder of GenExtreme is based on Coles parametrization - return stats.genextreme.rvs(c=-xi, loc=mu, scale=sigma, random_state=rng, size=size) + return stats.genextreme.rvs( + c=-xi, loc=mu, scale=sigma, random_state=rng, size=size + ) gev = GenExtremeRV() @@ -214,7 +218,9 @@ def support_point(rv, size, mu, sigma, xi): r""" Using the mode, as the mean can be infinite when 
:math:`\xi > 1` """ - mode = pt.switch(pt.isclose(xi, 0), mu, mu + sigma * (pt.pow(1 + xi, -xi) - 1) / xi) + mode = pt.switch( + pt.isclose(xi, 0), mu, mu + sigma * (pt.pow(1 + xi, -xi) - 1) / xi + ) if not rv_size_is_none(size): mode = pt.full(size, mode) return mode diff --git a/pymc_experimental/distributions/discrete.py b/pymc_experimental/distributions/discrete.py index 3934baa88..4a282c9f2 100644 --- a/pymc_experimental/distributions/discrete.py +++ b/pymc_experimental/distributions/discrete.py @@ -51,14 +51,14 @@ def rng_fn(cls, rng, theta, lam, size): x = np.empty(dist_size) idxs_mask = np.broadcast_to(lam < 0, dist_size) if np.any(idxs_mask): - x[idxs_mask] = cls._inverse_rng_fn(rng, theta, lam, dist_size, idxs_mask=idxs_mask)[ - idxs_mask - ] + x[idxs_mask] = cls._inverse_rng_fn( + rng, theta, lam, dist_size, idxs_mask=idxs_mask + )[idxs_mask] idxs_mask = ~idxs_mask if np.any(idxs_mask): - x[idxs_mask] = cls._branching_rng_fn(rng, theta, lam, dist_size, idxs_mask=idxs_mask)[ - idxs_mask - ] + x[idxs_mask] = cls._branching_rng_fn( + rng, theta, lam, dist_size, idxs_mask=idxs_mask + )[idxs_mask] return x @classmethod @@ -159,7 +159,9 @@ def support_point(rv, size, mu, lam): def logp(value, mu, lam): mu_lam_value = mu + lam * value - logprob = np.log(mu) + logpow(mu_lam_value, value - 1) - mu_lam_value - factln(value) + logprob = ( + np.log(mu) + logpow(mu_lam_value, value - 1) - mu_lam_value - factln(value) + ) # Probability is 0 when value > m, where m is the largest positive integer for # which mu + m * lam > 0 (when lam < 0). diff --git a/pymc_experimental/distributions/multivariate/__init__.py b/pymc_experimental/distributions/multivariate/__init__.py index 64a79b248..12f6b493f 100644 --- a/pymc_experimental/distributions/multivariate/__init__.py +++ b/pymc_experimental/distributions/multivariate/__init__.py @@ -1 +1,3 @@ from pymc_experimental.distributions.multivariate.r2d2m2cp import R2D2M2CP + +__all__ = ["R2D2M2CP"] diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py index f6214a7bd..f30f08746 100644 --- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py +++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py @@ -92,7 +92,9 @@ def _R2D2M2CP_beta( raw = pt.zeros_like(mu_param) else: raw = pm.Normal("raw", dims=dims) - beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims) + beta = pm.Deterministic( + name, (raw * std_param + mu_param) / input_sigma, dims=dims + ) else: if psi_mask is not None and psi_mask.any(): # limit case where some probs are not 1 or 0 @@ -113,7 +115,9 @@ def _R2D2M2CP_beta( # all variables are deterministic beta = pm.Deterministic(name, (mu_param / input_sigma), dims=dims) else: - beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims) + beta = pm.Normal( + name, mu_param / input_sigma, std_param / input_sigma, dims=dims + ) return beta @@ -137,7 +141,8 @@ def _psi_masked( dims: Sequence[str], ) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]: if not ( - isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant) + isinstance(positive_probs, pt.Constant) + and isinstance(positive_probs_std, pt.Constant) ): raise TypeError( "Only constant values for positive_probs and positive_probs_std are accepted" @@ -147,7 +152,9 @@ def _psi_masked( ) mask = ~np.bitwise_or(positive_probs == 1, positive_probs == 0) if np.bitwise_and(~mask, positive_probs_std != 0).any(): - raise 
ValueError("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0") + raise ValueError( + "Can't have both positive_probs == '1 or 0' and positive_probs_std != 0" + ) if (~mask).any() and mask.any(): # limit case where some probs are not 1 or 0 # setsubtensor is required @@ -206,7 +213,9 @@ def _phi( if variance_explained is not None: raise TypeError("Can't use variable importance with variance explained") if len(model.coords[dim]) <= 1: - raise TypeError("Can't use variable importance with less than two variables") + raise TypeError( + "Can't use variable importance with less than two variables" + ) variables_importance = pt.as_tensor(variables_importance) if importance_concentration is not None: variables_importance *= importance_concentration @@ -218,7 +227,9 @@ def _phi( else: phi = _broadcast_as_dims(1.0, dims=dims) if importance_concentration is not None: - return pm.Dirichlet("phi", importance_concentration * phi, dims=broadcast_dims + [dim]) + return pm.Dirichlet( + "phi", importance_concentration * phi, dims=broadcast_dims + [dim] + ) else: return phi @@ -428,7 +439,9 @@ def R2D2M2CP( dims=dims, ) mask, psi = _psi( - positive_probs=positive_probs, positive_probs_std=positive_probs_std, dims=dims + positive_probs=positive_probs, + positive_probs_std=positive_probs_std, + dims=dims, ) beta = _R2D2M2CP_beta( diff --git a/pymc_experimental/distributions/timeseries.py b/pymc_experimental/distributions/timeseries.py index 91da141ac..86a059fae 100644 --- a/pymc_experimental/distributions/timeseries.py +++ b/pymc_experimental/distributions/timeseries.py @@ -26,7 +26,9 @@ from pytensor.tensor.random.op import RandomVariable -def _make_outputs_info(n_lags: int, init_dist: Distribution) -> List[Union[Distribution, dict]]: +def _make_outputs_info( + n_lags: int, init_dist: Distribution +) -> List[Union[Distribution, dict]]: """ Two cases are needed for outputs_info in the scans used by DiscreteMarkovRv. 
If n_lags = 1, we need to throw away the first dimension of init_dist_ or else markov_chain will have shape (steps, 1, *batch_size) instead of @@ -124,7 +126,9 @@ def __new__(cls, *args, steps=None, n_lags=1, **kwargs): @classmethod def dist(cls, P=None, logit_P=None, steps=None, init_dist=None, n_lags=1, **kwargs): steps = get_support_shape_1d( - support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=n_lags + support_shape=steps, + shape=kwargs.get("shape", None), + support_shape_offset=n_lags, ) if steps is None: @@ -199,7 +203,9 @@ def transition(*args): (state_next_rng,) = tuple(state_updates.values()) - discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1) + discrete_mc_ = pt.moveaxis( + pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1 + ) discrete_mc_op = DiscreteMarkovChainRV( inputs=[P_, steps_, init_dist_, state_rng], @@ -218,7 +224,9 @@ def change_mc_size(op, dist, new_size, expand=False): old_size = dist.shape[:-1] new_size = tuple(new_size) + tuple(old_size) - return DiscreteMarkovChain.rv_op(*dist.owner.inputs[:-1], size=new_size, n_lags=op.n_lags) + return DiscreteMarkovChain.rv_op( + *dist.owner.inputs[:-1], size=new_size, n_lags=op.n_lags + ) @_support_point.register(DiscreteMarkovChainRV) @@ -247,7 +255,10 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs): value = values[0] n_lags = op.n_lags - indexes = [value[..., i : -(n_lags - i) if n_lags != i else None] for i in range(n_lags + 1)] + indexes = [ + value[..., i : -(n_lags - i) if n_lags != i else None] + for i in range(n_lags + 1) + ] mc_logprob = logp(init_dist, value[..., :n_lags]).sum(axis=-1) mc_logprob += pt.log(P[tuple(indexes)]).sum(axis=-1) diff --git a/pymc_experimental/gp/__init__.py b/pymc_experimental/gp/__init__.py index c8804dd4e..ae827e947 100644 --- a/pymc_experimental/gp/__init__.py +++ b/pymc_experimental/gp/__init__.py @@ -14,3 +14,5 @@ from pymc_experimental.gp.latent_approx import KarhunenLoeveExpansion, ProjectedProcess + +__all__ = ["KarhunenLoeveExpansion", "ProjectedProcess"] diff --git a/pymc_experimental/gp/latent_approx.py b/pymc_experimental/gp/latent_approx.py index ddcbb845d..29579ed6f 100644 --- a/pymc_experimental/gp/latent_approx.py +++ b/pymc_experimental/gp/latent_approx.py @@ -47,7 +47,9 @@ def _build_prior(self, name, X, X_inducing, jitter=JITTER_DEFAULT, **kwargs): L = cholesky(stabilize(Kuu, jitter)) n_inducing_points = np.shape(X_inducing)[0] - v = pm.Normal(name + "_u_rotated_", mu=0.0, sigma=1.0, size=n_inducing_points, **kwargs) + v = pm.Normal( + name + "_u_rotated_", mu=0.0, sigma=1.0, size=n_inducing_points, **kwargs + ) u = pm.Deterministic(name + "_u", L @ v) Kfu = self.cov_func(X, X_inducing) @@ -111,7 +113,9 @@ def _build_conditional(self, name, Xnew, X_inducing, L, Kuuiu, jitter, **kwargs) Ksu = self.cov_func(Xnew, X_inducing) mu = self.mean_func(Xnew) + Ksu @ Kuuiu tmp = solve_lower(L, pt.transpose(Ksu)) - Qss = pt.transpose(tmp) @ tmp # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T) + Qss = ( + pt.transpose(tmp) @ tmp + ) # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T) Kss = self.cov_func(Xnew) Lss = cholesky(stabilize(Kss - Qss, jitter)) return mu, Lss @@ -137,7 +141,7 @@ def __init__( super().__init__(mean_func=mean_func, cov_func=cov_func) def _build_prior(self, name, X, jitter=1e-6, **kwargs): - mu = self.mean_func(X) + # mu = self.mean_func(X) Kxx = pm.gp.util.stabilize(self.cov_func(X), jitter) vals, vecs = pt.linalg.eigh(Kxx) ## NOTE: REMOVED PRECISION 
CUTOFF @@ -147,7 +151,9 @@ def _build_prior(self, name, X, jitter=1e-6, **kwargs): if self.variance_limit == 1: n_eigs = len(vals) else: - n_eigs = ((vals[::-1].cumsum() / vals.sum()) > self.variance_limit).nonzero()[0][0] + n_eigs = ( + (vals[::-1].cumsum() / vals.sum()) > self.variance_limit + ).nonzero()[0][0] U = vecs[:, -n_eigs:] s = vals[-n_eigs:] basis = U * pt.sqrt(s) diff --git a/pymc_experimental/inference/__init__.py b/pymc_experimental/inference/__init__.py index c74607bf5..8b5dbe189 100644 --- a/pymc_experimental/inference/__init__.py +++ b/pymc_experimental/inference/__init__.py @@ -14,3 +14,5 @@ from pymc_experimental.inference.fit import fit + +__all__ = ["fit"] diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py index 71dfb9f8b..4f2f8fbfa 100644 --- a/pymc_experimental/inference/fit.py +++ b/pymc_experimental/inference/fit.py @@ -31,7 +31,7 @@ def fit(method, **kwargs): """ if method == "pathfinder": try: - import blackjax + import blackjax # noqa: F401 except ImportError as exc: raise RuntimeError("Need BlackJAX to use `pathfinder`") from exc @@ -40,7 +40,6 @@ def fit(method, **kwargs): return fit_pathfinder(**kwargs) if method == "laplace": - from pymc_experimental.inference.laplace import laplace return laplace(**kwargs) diff --git a/pymc_experimental/inference/laplace.py b/pymc_experimental/inference/laplace.py index 1508b6e88..60c307ab1 100644 --- a/pymc_experimental/inference/laplace.py +++ b/pymc_experimental/inference/laplace.py @@ -146,11 +146,15 @@ def addFitToInferenceData(vars, idata, mean, covariance): # Convert to xarray DataArray mean_dataarray = xr.DataArray(mean, dims=["rows"], coords={"rows": coord_names}) cov_dataarray = xr.DataArray( - covariance, dims=["rows", "columns"], coords={"rows": coord_names, "columns": coord_names} + covariance, + dims=["rows", "columns"], + coords={"rows": coord_names, "columns": coord_names}, ) # Create xarray dataset - dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray}) + dataset = xr.Dataset( + {"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray} + ) idata.add_groups(fit=dataset) diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index 5e5533dd6..c507b2a5d 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -48,7 +48,9 @@ def convert_flat_trace_to_idata( trace = {k: np.asarray(v)[None, ...] 
for k, v in trace.items()} var_names = model.unobserved_value_vars - vars_to_sample = list(get_default_varnames(var_names, include_transformed=include_transformed)) + vars_to_sample = list( + get_default_varnames(var_names, include_transformed=include_transformed) + ) print("Transforming variables...", file=sys.stdout) jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) result = jax.vmap(jax.vmap(jax_fn))( diff --git a/pymc_experimental/inference/smc/sampling.py b/pymc_experimental/inference/smc/sampling.py index 93f8a8a32..954ab5c87 100644 --- a/pymc_experimental/inference/smc/sampling.py +++ b/pymc_experimental/inference/smc/sampling.py @@ -198,7 +198,9 @@ def arviz_from_particles(model, particles): ------- """ n_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] - by_varname = {k.name: v.squeeze()[np.newaxis, :] for k, v in zip(model.value_vars, particles)} + by_varname = { + k.name: v.squeeze()[np.newaxis, :] for k, v in zip(model.value_vars, particles) + } varnames = [v.name for v in model.value_vars] with model: strace = NDArray(name=model.name) @@ -344,18 +346,18 @@ def add_to_inference_data( "sampler": f"Blackjax SMC with {kernel} kernel", } - inference_data.posterior.attrs["lambda_evolution"] = np.array(diagnosis.lmbda_evolution)[ - :iterations_to_diagnose - ] + inference_data.posterior.attrs["lambda_evolution"] = np.array( + diagnosis.lmbda_evolution + )[:iterations_to_diagnose] inference_data.posterior.attrs["log_likelihood_increments"] = np.array( diagnosis.log_likelihood_increment_evolution )[:iterations_to_diagnose] - inference_data.posterior.attrs["ancestors_evolution"] = np.array(diagnosis.ancestors_evolution)[ - :iterations_to_diagnose - ] - inference_data.posterior.attrs["weights_evolution"] = np.array(diagnosis.weights_evolution)[ - :iterations_to_diagnose - ] + inference_data.posterior.attrs["ancestors_evolution"] = np.array( + diagnosis.ancestors_evolution + )[:iterations_to_diagnose] + inference_data.posterior.attrs["weights_evolution"] = np.array( + diagnosis.weights_evolution + )[:iterations_to_diagnose] for k in experiment_parameters: inference_data.posterior.attrs[k] = experiment_parameters[k] @@ -391,7 +393,9 @@ def logp_fn_wrap(particles): def initialize_population(model, draws, random_seed) -> Dict[str, np.ndarray]: with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, message="The effect of Potentials") + warnings.filterwarnings( + "ignore", category=UserWarning, message="The effect of Potentials" + ) prior_expression = make_initial_point_expression( free_rvs=model.free_RVs, diff --git a/pymc_experimental/linearmodel.py b/pymc_experimental/linearmodel.py index 0c4237dab..d2dd94258 100644 --- a/pymc_experimental/linearmodel.py +++ b/pymc_experimental/linearmodel.py @@ -8,7 +8,9 @@ class LinearModel(ModelBuilder): - def __init__(self, model_config: Dict = None, sampler_config: Dict = None, nsamples=100): + def __init__( + self, model_config: Dict = None, sampler_config: Dict = None, nsamples=100 + ): self.nsamples = nsamples super().__init__(model_config, sampler_config) @@ -80,10 +82,12 @@ def build_model(self, X: pd.DataFrame, y: pd.Series): obs_error = pm.HalfNormal("σ_model_fmc", cfg["obs_error"]) # Model - y_model = pm.Deterministic("y_model", intercept + slope * x, dims="observation") + y_model = pm.Deterministic( + "y_model", intercept + slope * x, dims="observation" + ) # observed data - y_hat = pm.Normal( + pm.Normal( "y_hat", y_model, sigma=obs_error, @@ -94,7 +98,9 @@ def 
build_model(self, X: pd.DataFrame, y: pd.Series): self._data_setter(X, y) - def _data_setter(self, X: pd.DataFrame, y: Optional[Union[pd.DataFrame, pd.Series]] = None): + def _data_setter( + self, X: pd.DataFrame, y: Optional[Union[pd.DataFrame, pd.Series]] = None + ): with self.model: pm.set_data({"x": X.squeeze()}) if y is not None: diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py index ead9a362b..e8988ee05 100644 --- a/pymc_experimental/model/marginal_model.py +++ b/pymc_experimental/model/marginal_model.py @@ -18,6 +18,7 @@ from pytensor import Mode, scan from pytensor.compile import SharedVariable from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace +from pytensor.graph.basic import graph_inputs from pytensor.graph.replace import graph_replace, vectorize_graph from pytensor.scan import map as scan_map from pytensor.tensor import TensorType, TensorVariable @@ -106,7 +107,9 @@ def _delete_rv_mappings(self, rv: TensorVariable) -> None: else: self.observed_RVs.remove(rv) - def _transfer_rv_mappings(self, old_rv: TensorVariable, new_rv: TensorVariable) -> None: + def _transfer_rv_mappings( + self, old_rv: TensorVariable, new_rv: TensorVariable + ) -> None: """Transfer model mappings from old_rv to new_rv""" assert old_rv in self.basic_RVs, "old_rv is not part of the Model" @@ -188,7 +191,8 @@ def _marginalize(self, user_warnings=False): if isinstance(transform, IntervalTransform) or ( isinstance(transform, Chain) and any( - isinstance(tr, IntervalTransform) for tr in transform.transform_list + isinstance(tr, IntervalTransform) + for tr in transform.transform_list ) ): warnings.warn( @@ -219,8 +223,12 @@ def from_model(model: Union[Model, "MarginalModel"]) -> "MarginalModel": marginalized_rvs = [] marginalized_named_vars_to_dims = {} - model_vars = model.basic_RVs + model.potentials + model.deterministics + marginalized_rvs - data_vars = [var for name, var in model.named_vars.items() if var not in model_vars] + model_vars = ( + model.basic_RVs + model.potentials + model.deterministics + marginalized_rvs + ) + data_vars = [ + var for name, var in model.named_vars.items() if var not in model_vars + ] vars = model_vars + data_vars cloned_vars = clone_replace(vars) vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)} @@ -230,8 +238,12 @@ def from_model(model: Union[Model, "MarginalModel"]) -> "MarginalModel": {name: vars_to_clone[var] for name, var in model.named_vars.items()} ) new_model.named_vars_to_dims = model.named_vars_to_dims - new_model.values_to_rvs = {vv: vars_to_clone[rv] for vv, rv in model.values_to_rvs.items()} - new_model.rvs_to_values = {vars_to_clone[rv]: vv for rv, vv in model.rvs_to_values.items()} + new_model.values_to_rvs = { + vv: vars_to_clone[rv] for vv, rv in model.values_to_rvs.items() + } + new_model.rvs_to_values = { + vars_to_clone[rv]: vv for rv, vv in model.rvs_to_values.items() + } new_model.rvs_to_transforms = { vars_to_clone[rv]: tr for rv, tr in model.rvs_to_transforms.items() } @@ -376,7 +388,9 @@ def recover_marginals( var_names = [var if isinstance(var, str) else var.name for var in var_names] vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names] - missing_names = [v.name for v in vars_to_recover if v not in self.marginalized_rvs] + missing_names = [ + v.name for v in vars_to_recover if v not in self.marginalized_rvs + ] if missing_names: raise ValueError(f"Unrecognized var_names: {missing_names}") @@ -393,7 +407,9 @@ def 
recover_marginals( ] sample_dims = ("chain", "draw") - posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims) + posterior_pts, stacked_dims = dataset_to_point_list( + posterior_values, sample_dims + ) # Handle Transforms transform_fn, transform_names = self._to_transformed() @@ -416,7 +432,9 @@ def transform_input(inputs): m = self.clone() marginalized_rv = m.vars_to_clone[marginalized_rv] m.unmarginalize([marginalized_rv]) - dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs) + dependent_vars = find_conditional_dependent_rvs( + marginalized_rv, m.basic_RVs + ) joint_logps = m.logp(vars=[marginalized_rv] + dependent_vars, sum=False) marginalized_value = m.rvs_to_values[marginalized_rv] @@ -428,7 +446,9 @@ def transform_input(inputs): marginalized_rv.type, dependent_logps ) - rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False) + rv_shape = constant_fold( + tuple(marginalized_rv.shape), raise_not_constant=False + ) rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) rv_domain_tensor = pt.moveaxis( pt.full( @@ -476,7 +496,8 @@ def transform_input(inputs): logps = np.array(logps) samples = np.array(samples) rv_dict[marginalized_rv.name] = samples.reshape( - tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:], + tuple(len(coord) for coord in stacked_dims.values()) + + samples.shape[1:], ) else: logps = np.array(logvs) @@ -485,10 +506,12 @@ def transform_input(inputs): tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:], ) if marginalized_rv.name in m.named_vars_to_dims: - rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name]) - rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [ - "lp_" + marginalized_rv.name + "_dim" - ] + rv_dims[marginalized_rv.name] = list( + m.named_vars_to_dims[marginalized_rv.name] + ) + rv_dims["lp_" + marginalized_rv.name] = rv_dims[ + marginalized_rv.name + ] + ["lp_" + marginalized_rv.name + "_dim"] coords, dims = coords_and_dims_for_inferencedata(self) dims.update(rv_dims) @@ -530,7 +553,9 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel: """ if not isinstance(rvs_to_marginalize, tuple | list): rvs_to_marginalize = (rvs_to_marginalize,) - rvs_to_marginalize = [rv if isinstance(rv, str) else rv.name for rv in rvs_to_marginalize] + rvs_to_marginalize = [ + rv if isinstance(rv, str) else rv.name for rv in rvs_to_marginalize + ] marginal_model = MarginalModel.from_model(model) marginal_model.marginalize(rvs_to_marginalize) @@ -589,7 +614,10 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs): return [ rv for rv in all_rvs - if (rv is not dependable_rv and is_conditional_dependent(rv, dependable_rv, all_rvs)) + if ( + rv is not dependable_rv + and is_conditional_dependent(rv, dependable_rv, all_rvs) + ) ] @@ -628,7 +656,9 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs): return False # Check that none of the truncated inputs depends on the marginalized_rv - other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize] + other_truncated_inputs = [ + inp for inp in truncated_inputs if inp is not rv_to_marginalize + ] # TODO: We don't need to go all the way to the root variables if rv_to_marginalize in ancestors( other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs] @@ -637,12 +667,11 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs): return True -from 
pytensor.graph.basic import graph_inputs - - def collect_shared_vars(outputs, blockers): return [ - inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable) + inp + for inp in graph_inputs(outputs, blockers=blockers) + if isinstance(inp, SharedVariable) ] @@ -662,7 +691,9 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs ) [ndim_supp] = ndim_supp if ndim_supp > 0: - raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented") + raise NotImplementedError( + "Marginalization with dependent Multivariate RVs not implemented" + ) marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) dependent_rvs_input_rvs = [ @@ -677,8 +708,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs # can ultimately be generated that is proportional to the support domain and not # to the variables dimensions # We don't need to worry about this if the RV is scalar. - if np.prod(constant_fold(tuple(rv_to_marginalize.shape), raise_not_constant=False)) != 1: - if not is_elemwise_subgraph(rv_to_marginalize, dependent_rvs_input_rvs, dependent_rvs): + if ( + np.prod(constant_fold(tuple(rv_to_marginalize.shape), raise_not_constant=False)) + != 1 + ): + if not is_elemwise_subgraph( + rv_to_marginalize, dependent_rvs_input_rvs, dependent_rvs + ): raise NotImplementedError( "The subgraph between a marginalized RV and its dependents includes non Elemwise operations. " "This is currently not supported", @@ -736,7 +772,9 @@ def _add_reduce_batch_dependent_logps( dbcast = dependent_logp.type.broadcastable dim_diff = len(dbcast) - len(mbcast) mbcast_aligned = (True,) * dim_diff + mbcast - vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v] + vbcast_axis = [ + i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v + ] reduced_logps.append(dependent_logp.sum(vbcast_axis)) return pt.add(*reduced_logps) @@ -768,7 +806,9 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): # batched dimensions of the marginalized RV # PyMC does not allow RVs in the logp graph, even if we are just using the shape - marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False) + marginalized_rv_shape = constant_fold( + tuple(marginalized_rv.shape), raise_not_constant=False + ) marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) marginalized_rv_domain_tensor = pt.moveaxis( pt.full( @@ -787,7 +827,9 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): except Exception: # Fallback to Scan def logp_fn(marginalized_rv_const, *non_sequences): - return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const}) + return graph_replace( + joint_logp, replace={marginalized_vv: marginalized_rv_const} + ) joint_logps, _ = scan_map( fn=logp_fn, @@ -804,7 +846,6 @@ def logp_fn(marginalized_rv_const, *non_sequences): @_logprob.register(DiscreteMarginalMarkovChainRV) def marginal_hmm_logp(op, values, *inputs, **kwargs): - marginalized_rvs_node = op.make_node(*inputs) inner_rvs = clone_replace( op.inner_outputs, @@ -833,7 +874,9 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs): # Add a batch dimension for the domain of the chain chain_shape = constant_fold(tuple(chain_rv.shape)) batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0) - batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: 
batch_chain_value}) + batch_logp_emissions = vectorize_graph( + reduced_logp_emissions, {chain_value: batch_chain_value} + ) # Step 2: Compute the transition probabilities # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1}) diff --git a/pymc_experimental/model/transforms/autoreparam.py b/pymc_experimental/model/transforms/autoreparam.py index bb3996459..7e78247fc 100644 --- a/pymc_experimental/model/transforms/autoreparam.py +++ b/pymc_experimental/model/transforms/autoreparam.py @@ -42,7 +42,9 @@ class VIP: _logit_lambda: Dict[str, pytensor.tensor.sharedvar.TensorSharedVariable] @property - def variational_parameters(self) -> List[pytensor.tensor.sharedvar.TensorSharedVariable]: + def variational_parameters( + self, + ) -> List[pytensor.tensor.sharedvar.TensorSharedVariable]: r"""Return raw :math:`\operatorname{logit}(\lambda_k)` for custom optimization. Examples @@ -222,7 +224,9 @@ def _( ) -> ModelDeterministic: rng, size, loc, scale = node.inputs if transform is not None: - raise NotImplementedError("Reparametrization of Normal with Transform is not implemented") + raise NotImplementedError( + "Reparametrization of Normal with Transform is not implemented" + ) vip_rv_ = pm.Normal.dist( lam * loc, scale**lam, @@ -418,6 +422,6 @@ def vip_reparametrize( lambda_names.append(lam.name) toposort_replace(fmodel, replacements, reverse=True) reparam_model = model_from_fgraph(fmodel) - model_lambdas = {n: reparam_model[l] for l, n in zip(lambda_names, var_names)} + model_lambdas = {n: reparam_model[lam] for lam, n in zip(lambda_names, var_names)} vip = VIP(model_lambdas) return reparam_model, vip diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index 6e2a256e6..89626e40d 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -72,19 +72,27 @@ def __init__( >>> model = MyModel(model_config, sampler_config) """ sampler_config = ( - self.get_default_sampler_config() if sampler_config is None else sampler_config + self.get_default_sampler_config() + if sampler_config is None + else sampler_config ) self.sampler_config = sampler_config - model_config = self.get_default_model_config() if model_config is None else model_config + model_config = ( + self.get_default_model_config() if model_config is None else model_config + ) self.model_config = model_config # parameters for priors etc. 
self.model = None # Set by build_model - self.idata: Optional[az.InferenceData] = None # idata is generated during fitting + self.idata: Optional[az.InferenceData] = ( + None # idata is generated during fitting + ) self.is_fitted_ = False def _validate_data(self, X, y=None): if y is not None: - return check_X_y(X, y, accept_sparse=False, y_numeric=True, multi_output=False) + return check_X_y( + X, y, accept_sparse=False, y_numeric=True, multi_output=False + ) else: return check_array(X, accept_sparse=False) @@ -396,10 +404,14 @@ def _model_config_formatting(cls, model_config: Dict) -> Dict: if isinstance(model_config[key][sub_key], list): # Check if "dims" key to convert it to tuple if sub_key == "dims": - model_config[key][sub_key] = tuple(model_config[key][sub_key]) + model_config[key][sub_key] = tuple( + model_config[key][sub_key] + ) # Convert all other lists to numpy arrays else: - model_config[key][sub_key] = np.array(model_config[key][sub_key]) + model_config[key][sub_key] = np.array( + model_config[key][sub_key] + ) return model_config @classmethod @@ -431,7 +443,9 @@ def load(cls, fname: str): filepath = Path(str(fname)) idata = az.from_netcdf(filepath) # needs to be converted, because json.loads was changing tuple to list - model_config = cls._model_config_formatting(json.loads(idata.attrs["model_config"])) + model_config = cls._model_config_formatting( + json.loads(idata.attrs["model_config"]) + ) model = cls( model_config=model_config, sampler_config=json.loads(idata.attrs["sampler_config"]), @@ -612,7 +626,9 @@ def sample_prior_predictive( else: self.idata = prior_pred - prior_predictive_samples = az.extract(prior_pred, "prior_predictive", combined=combined) + prior_predictive_samples = az.extract( + prior_pred, "prior_predictive", combined=combined + ) return prior_predictive_samples diff --git a/pymc_experimental/statespace/core/representation.py b/pymc_experimental/statespace/core/representation.py index 41c316134..531cf3aec 100644 --- a/pymc_experimental/statespace/core/representation.py +++ b/pymc_experimental/statespace/core/representation.py @@ -223,7 +223,9 @@ def _validate_key(self, key: KeyLike) -> None: if key not in self.shapes: raise IndexError(f"{key} is an invalid state space matrix name") - def _update_shape(self, key: KeyLike, value: Union[np.ndarray, pt.Variable]) -> None: + def _update_shape( + self, key: KeyLike, value: Union[np.ndarray, pt.Variable] + ) -> None: if isinstance(value, (pt.TensorConstant, pt.TensorVariable)): shape = value.type.shape else: @@ -231,7 +233,9 @@ def _update_shape(self, key: KeyLike, value: Union[np.ndarray, pt.Variable]) -> old_shape = self.shapes[key] ndim_core = 1 if key in VECTOR_VALUED else 2 - if not all([a == b for a, b in zip(shape[-ndim_core:], old_shape[-ndim_core:])]): + if not all( + [a == b for a, b in zip(shape[-ndim_core:], old_shape[-ndim_core:])] + ): raise ValueError( f"The last two dimensions of {key} must be {old_shape[-ndim_core:]}, found {shape[-ndim_core:]}" ) @@ -269,7 +273,9 @@ def _validate_key_and_get_type(key: KeyLike) -> Type[str]: return type(key) - def _validate_matrix_shape(self, name: str, X: Union[np.ndarray, pt.TensorVariable]) -> None: + def _validate_matrix_shape( + self, name: str, X: Union[np.ndarray, pt.TensorVariable] + ) -> None: time_dim, *expected_shape = self.shapes[name] expected_shape = tuple(expected_shape) shape = X.shape if isinstance(X, np.ndarray) else X.type.shape @@ -328,7 +334,9 @@ def _validate_matrix_shape(self, name: str, X: Union[np.ndarray, pt.TensorVariab # f"provided 
data)" # ) - def _check_provided_tensor(self, name: str, X: pt.TensorVariable) -> pt.TensorVariable: + def _check_provided_tensor( + self, name: str, X: pt.TensorVariable + ) -> pt.TensorVariable: self._validate_matrix_shape(name, X) if name not in NEVER_TIME_VARYING: if X.ndim == 1 and name in VECTOR_VALUED: @@ -407,7 +415,9 @@ def __getitem__(self, key: KeyLike) -> pt.TensorVariable: else: raise IndexError("First index must the name of a valid state space matrix.") - def __setitem__(self, key: KeyLike, value: Union[float, int, np.ndarray, pt.Variable]) -> None: + def __setitem__( + self, key: KeyLike, value: Union[float, int, np.ndarray, pt.Variable] + ) -> None: _type = type(key) # Case 1: key is a string: we are setting an entire matrix. diff --git a/pymc_experimental/statespace/core/statespace.py b/pymc_experimental/statespace/core/statespace.py index 9e7b9fa47..e2305e369 100644 --- a/pymc_experimental/statespace/core/statespace.py +++ b/pymc_experimental/statespace/core/statespace.py @@ -64,7 +64,9 @@ def _validate_filter_arg(filter_arg): def _verify_group(group): if group not in ["prior", "posterior"]: - raise ValueError(f'Argument "group" must be one of "prior" or "posterior", found {group}') + raise ValueError( + f'Argument "group" must be one of "prior" or "posterior", found {group}' + ) class PyMCStateSpace: @@ -244,11 +246,14 @@ def __init__( if filter_type.lower() not in FILTER_FACTORY.keys(): raise NotImplementedError( - "The following are valid filter types: " + ", ".join(list(FILTER_FACTORY.keys())) + "The following are valid filter types: " + + ", ".join(list(FILTER_FACTORY.keys())) ) if filter_type == "single" and self.k_endog > 1: - raise ValueError('Cannot use filter_type = "single" with multiple observed time series') + raise ValueError( + 'Cannot use filter_type = "single" with multiple observed time series' + ) self.kalman_filter = FILTER_FACTORY[filter_type.lower()]() self.kalman_smoother = KalmanSmoother() @@ -395,7 +400,9 @@ def observed_states(self) -> list[str]: """ A k_endog length list of strings, associated with the model's observed states """ - raise NotImplementedError("The observed_states property has not been implemented!") + raise NotImplementedError( + "The observed_states property has not been implemented!" + ) @property def shock_names(self) -> list[str]: @@ -413,7 +420,9 @@ def default_priors(self) -> dict[str, Callable]: Returns a dictionary with param_name: Callable key-value pairs. Used by the ``add_default_priors()`` method to automatically add priors to the PyMC model. """ - raise NotImplementedError("The default_priors property has not been implemented!") + raise NotImplementedError( + "The default_priors property has not been implemented!" + ) @property def coords(self) -> dict[str, Sequence[str]]: @@ -442,7 +451,9 @@ def add_default_priors(self) -> None: """ Add default priors to the active PyMC model context """ - raise NotImplementedError("The add_default_priors property has not been implemented!") + raise NotImplementedError( + "The add_default_priors property has not been implemented!" + ) def make_and_register_variable( self, name, shape: int | tuple[int] | None = None, dtype=floatX @@ -584,7 +595,9 @@ def make_symbolic_graph(self) -> None: self.ssm['selection', 1:, 0] = theta_params self.ssm['state_cov', 0, 0] = sigma """ - raise NotImplementedError("The make_symbolic_statespace method has not been implemented!") + raise NotImplementedError( + "The make_symbolic_statespace method has not been implemented!" 
+ ) def _get_matrix_shape_and_dims( self, name: str @@ -694,7 +707,9 @@ def _insert_random_variables(self): matrices = list(self._unpack_statespace_with_placeholders()) - replacement_dict = {var: pymc_model[name] for name, var in self._name_to_variable.items()} + replacement_dict = { + var: pymc_model[name] for name, var in self._name_to_variable.items() + } self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True) def _insert_data_variables(self): @@ -724,8 +739,12 @@ def _insert_data_variables(self): + ", ".join(missing_data) ) - replacement_dict = {data: pymc_model[name] for name, data in self._name_to_data.items()} - self.subbed_ssm = graph_replace(self.subbed_ssm, replace=replacement_dict, strict=True) + replacement_dict = { + data: pymc_model[name] for name, data in self._name_to_data.items() + } + self.subbed_ssm = graph_replace( + self.subbed_ssm, replace=replacement_dict, strict=True + ) def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]: """ @@ -759,7 +778,9 @@ def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]: return registered_matrices @staticmethod - def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVariable]) -> None: + def _register_kalman_filter_outputs_with_pymc_model( + outputs: tuple[pt.TensorVariable], + ) -> None: mod = modelcontext(None) coords = mod.coords @@ -781,7 +802,9 @@ def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVari with mod: for var, name in zip(states + covs, state_names + cov_names): dim_names = FILTER_OUTPUT_DIMS.get(name, None) - dims = tuple([dim if dim in coords.keys() else None for dim in dim_names]) + dims = tuple( + [dim if dim in coords.keys() else None for dim in dim_names] + ) pm.Deterministic(name, var, dims=dims) def build_statespace_graph( @@ -870,13 +893,18 @@ def build_statespace_graph( filtered_covariances, predicted_covariances, observed_covariances = covs if save_kalman_filter_outputs_in_idata: smooth_states, smooth_covariances = self._build_smoother_graph( - filtered_states, filtered_covariances, self.unpack_statespace(), mode=mode + filtered_states, + filtered_covariances, + self.unpack_statespace(), + mode=mode, ) all_kf_outputs = states + [smooth_states] + covs + [smooth_covariances] self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs) obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"] - obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None + obs_dims = ( + obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None + ) SequenceMvNormal( "obs", @@ -934,7 +962,13 @@ def _build_smoother_graph( *_, T, Z, R, H, Q = matrices smooth_states, smooth_covariances = self.kalman_smoother.build_graph( - T, R, Q, filtered_states, filtered_covariances, mode=mode, cov_jitter=cov_jitter + T, + R, + Q, + filtered_states, + filtered_covariances, + mode=mode, + cov_jitter=cov_jitter, ) smooth_states.name = "smooth_states" smooth_covariances.name = "smooth_covariances" @@ -963,7 +997,9 @@ def _kalman_filter_outputs_from_dummy_graph( self, data: pt.TensorLike | None = None, data_dims: str | tuple[str] | list[str] | None = None, - ) -> tuple[list[pt.TensorVariable], list[tuple[pt.TensorVariable, pt.TensorVariable]]]: + ) -> tuple[ + list[pt.TensorVariable], list[tuple[pt.TensorVariable, pt.TensorVariable]] + ]: """ Builds a Kalman filter graph using "dummy" pm.Flat distributions for the model variables and sorts the returns into (mean, covariance) pairs for each 
of filtered, predicted, and smoothed output. @@ -1084,29 +1120,36 @@ def _sample_conditional( group_idata = getattr(idata, group) with pm.Model(coords=self._fit_coords) as forward_model: - [ - x0, - P0, - c, - d, - T, - Z, - R, - H, - Q, - ], grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(data=data) + ( + [ + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + ], + grouped_outputs, + ) = self._kalman_filter_outputs_from_dummy_graph(data=data) for name, (mu, cov) in zip(FILTER_OUTPUT_TYPES, grouped_outputs): dummy_ll = pt.zeros_like(mu) state_dims = ( (TIME_DIM, ALL_STATE_DIM) - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM]]) + if all( + [dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM]] + ) else (None, None) ) obs_dims = ( (TIME_DIM, OBS_STATE_DIM) - if all([dim in self._fit_coords for dim in [TIME_DIM, OBS_STATE_DIM]]) + if all( + [dim in self._fit_coords for dim in [TIME_DIM, OBS_STATE_DIM]] + ) else (None, None) ) @@ -1218,10 +1261,17 @@ def _sample_unconditional( else: steps = len(temp_coords[TIME_DIM]) - 1 - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]]): + if all( + [ + dim in self._fit_coords + for dim in [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + ] + ): dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] - with pm.Model(coords=temp_coords if dims is not None else None) as forward_model: + with pm.Model( + coords=temp_coords if dims is not None else None + ) as forward_model: self._build_dummy_graph() self._insert_random_variables() @@ -1234,7 +1284,8 @@ def _sample_unconditional( if not self.measurement_error: H_jittered = pm.Deterministic( - "H_jittered", pt.specify_shape(stabilize(H), (self.k_endog, self.k_endog)) + "H_jittered", + pt.specify_shape(stabilize(H), (self.k_endog, self.k_endog)), ) matrices = [x0, P0, c, d, T, Z, R, H_jittered, Q] @@ -1470,7 +1521,10 @@ def sample_statespace_matrices( long_name = SHORT_NAME_TO_LONG[short_name] if (long_name in matrix_names) or (short_name in matrix_names): name = long_name if long_name in matrix_names else short_name - dims = [x if x in self._fit_coords else None for x in MATRIX_DIMS[short_name]] + dims = [ + x if x in self._fit_coords else None + for x in MATRIX_DIMS[short_name] + ] pm.Deterministic(name, matrix, dims=dims) # TODO: Remove this after pm.Flat has its initial_value fixed @@ -1565,7 +1619,12 @@ def forecast( filter_time_dim = TIME_DIM dims = None - if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): + if all( + [ + dim in temp_coords + for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM] + ] + ): dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] time_index = temp_coords[filter_time_dim] @@ -1599,22 +1658,30 @@ def forecast( temp_coords[TIME_DIM] = forecast_index mu_dims, cov_dims = None, None - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]): + if all( + [ + dim in self._fit_coords + for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM] + ] + ): mu_dims = ["data_time", ALL_STATE_DIM] cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM] with pm.Model(coords=temp_coords) as forecast_model: - [ - x0, - P0, - c, - d, - T, - Z, - R, - H, - Q, - ], grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( + ( + [ + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + ], + grouped_outputs, + ) = self._kalman_filter_outputs_from_dummy_graph( data_dims=["data_time", OBS_STATE_DIM] ) group_idx = FILTER_OUTPUT_TYPES.index(filter_output) @@ -1622,10 +1689,14 @@ 
def forecast( mu, cov = grouped_outputs[group_idx] x0 = pm.Deterministic( - "x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None + "x0_slice", + mu[t0_idx], + dims=mu_dims[1:] if mu_dims is not None else None, ) P0 = pm.Deterministic( - "P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None + "P0_slice", + cov[t0_idx], + dims=cov_dims[1:] if cov_dims is not None else None, ) _ = LinearGaussianStateSpace( @@ -1736,7 +1807,9 @@ def impulse_response_function( Q = None # No covariance matrix needed if a trajectory is provided. Will be overwritten later if needed. if n_options > 1: - raise ValueError("Specify exactly 0 or 1 of shock_size, shock_cov, or shock_trajectory") + raise ValueError( + "Specify exactly 0 or 1 of shock_size, shock_cov, or shock_trajectory" + ) elif n_options == 1: # If the user passed an alternative parameterization for the shocks of the IRF, don't use the posterior use_posterior_cov = False @@ -1767,7 +1840,9 @@ def impulse_response_function( self._insert_random_variables() P0, _, c, d, T, Z, R, H, post_Q = self.unpack_statespace() - x0 = pm.Deterministic("x0_new", pt.zeros(self.k_states), dims=[ALL_STATE_DIM]) + x0 = pm.Deterministic( + "x0_new", pt.zeros(self.k_states), dims=[ALL_STATE_DIM] + ) if use_posterior_cov: Q = post_Q @@ -1781,7 +1856,9 @@ def impulse_response_function( if shock_trajectory is None: shock_trajectory = pt.zeros((n_steps, self.k_posdef)) if Q is not None: - init_shock = pm.MvNormal("initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM]) + init_shock = pm.MvNormal( + "initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM] + ) else: init_shock = pm.Deterministic( "initial_shock", diff --git a/pymc_experimental/statespace/filters/distributions.py b/pymc_experimental/statespace/filters/distributions.py index edcc00c6e..b1f349390 100644 --- a/pymc_experimental/statespace/filters/distributions.py +++ b/pymc_experimental/statespace/filters/distributions.py @@ -43,7 +43,9 @@ def make_signature(sequence_names): base_shape = matrix_to_shape[matrix] matrix_to_shape[matrix] = (time,) + base_shape - signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in matrix_to_shape.values()]) + signature = ",".join( + ["(" + ",".join(shapes) + ")" for shapes in matrix_to_shape.values()] + ) return f"{signature},[rng]->[rng],({time},{state_and_obs})" @@ -87,7 +89,10 @@ def sample_fn(rng, size, dtype, *parameters): class LinearGaussianStateSpaceRV(SymbolicRandomVariable): default_output = 1 - _print_name = ("LinearGuassianStateSpace", "\\operatorname{LinearGuassianStateSpace}") + _print_name = ( + "LinearGuassianStateSpace", + "\\operatorname{LinearGuassianStateSpace}", + ) def update(self, node: Node): return {node.inputs[-1]: node.outputs[0]} @@ -143,7 +148,20 @@ def __new__( @classmethod def dist( - cls, a0, P0, c, d, T, Z, R, H, Q, steps=None, mode=None, sequence_names=None, **kwargs + cls, + a0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + steps=None, + mode=None, + sequence_names=None, + **kwargs, ): steps = get_support_shape_1d( support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=0 @@ -155,11 +173,29 @@ def dist( steps = pt.as_tensor_variable(intX(steps), ndim=0) return super().dist( - [a0, P0, c, d, T, Z, R, H, Q, steps], mode=mode, sequence_names=sequence_names, **kwargs + [a0, P0, c, d, T, Z, R, H, Q, steps], + mode=mode, + sequence_names=sequence_names, + **kwargs, ) @classmethod - def rv_op(cls, a0, P0, c, d, T, Z, R, H, Q, steps, size=None, mode=None, sequence_names=None): + def rv_op( # noqa: F811 + cls, 
+ a0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + steps, + size=None, + mode=None, + sequence_names=None, + ): if sequence_names is None: sequence_names = [] @@ -177,7 +213,9 @@ def rv_op(cls, a0, P0, c, d, T, Z, R, H, Q, steps, size=None, mode=None, sequenc sequences = [ x - for x, name in zip([c_, d_, T_, Z_, R_, H_, Q_], ["c", "d", "T", "Z", "R", "H", "Q"]) + for x, name in zip( + [c_, d_, T_, Z_, R_, H_, Q_], ["c", "d", "T", "Z", "R", "H", "Q"] + ) if name in sequence_names ] non_sequences = [x for x in [c_, d_, T_, Z_, R_, H_, Q_] if x not in sequences] @@ -208,8 +246,12 @@ def step_fn(*args): k = T.shape[0] a = state[:k] - middle_rng, a_innovation = MvNormalSVD.dist(mu=0, cov=Q, rng=rng).owner.outputs - next_rng, y_innovation = MvNormalSVD.dist(mu=0, cov=H, rng=middle_rng).owner.outputs + middle_rng, a_innovation = MvNormalSVD.dist( + mu=0, cov=Q, rng=rng + ).owner.outputs + next_rng, y_innovation = MvNormalSVD.dist( + mu=0, cov=H, rng=middle_rng + ).owner.outputs a_mu = c + T @ a a_next = a_mu + R @ a_innovation @@ -249,7 +291,9 @@ def step_fn(*args): extended_signature=make_signature(sequence_names), ) - linear_gaussian_ss = linear_gaussian_ss_op(a0, P0, c, d, T, Z, R, H, Q, steps, rng) + linear_gaussian_ss = linear_gaussian_ss_op( + a0, P0, c, d, T, Z, R, H, Q, steps, rng + ) return linear_gaussian_ss @@ -290,8 +334,6 @@ def __new__( latent_dims = [time_dim, state_dim] obs_dims = [time_dim, obs_dim] - matrices = () - latent_obs_combined = _LinearGaussianStateSpace( f"{name}_combined", a0, @@ -317,7 +359,9 @@ def __new__( latent_states = latent_obs_combined[..., latent_slice] obs_states = latent_obs_combined[..., obs_slice] - latent_states = pm.Deterministic(f"{name}_latent", latent_states, dims=latent_dims) + latent_states = pm.Deterministic( + f"{name}_latent", latent_states, dims=latent_dims + ) obs_states = pm.Deterministic(f"{name}_observed", obs_states, dims=obs_dims) return latent_states, obs_states @@ -370,7 +414,7 @@ def dist(cls, mus, covs, logp, **kwargs): return super().dist([mus, covs, logp], **kwargs) @classmethod - def rv_op(cls, mus, covs, logp, size=None): + def rv_op(cls, mus, covs, logp, size=None): # noqa: F811 # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead if mus.ndim > 2: mus = pt.moveaxis(mus, -2, 0) @@ -387,7 +431,11 @@ def step(mu, cov, rng): return mvn, {rng: new_rng} mvn_seq, updates = pytensor.scan( - step, sequences=[mus_, covs_], non_sequences=[rng], strict=True, n_steps=mus_.shape[0] + step, + sequences=[mus_, covs_], + non_sequences=[rng], + strict=True, + n_steps=mus_.shape[0], ) mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape) @@ -398,7 +446,9 @@ def step(mu, cov, rng): (seq_mvn_rng,) = tuple(updates.values()) mvn_seq_op = KalmanFilterRV( - inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2 + inputs=[mus_, covs_, logp_, rng], + outputs=[seq_mvn_rng, mvn_seq], + ndim_supp=2, ) mvn_seq = mvn_seq_op(mus, covs, logp, rng) diff --git a/pymc_experimental/statespace/filters/kalman_filter.py b/pymc_experimental/statespace/filters/kalman_filter.py index 87fdc746a..bfdefead7 100644 --- a/pymc_experimental/statespace/filters/kalman_filter.py +++ b/pymc_experimental/statespace/filters/kalman_filter.py @@ -21,7 +21,9 @@ MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64")) PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"] -assert_data_is_1d = Assert("UnivariateTimeSeries filter requires data be at most 1-dimensional") +assert_data_is_1d = Assert( + "UnivariateTimeSeries filter 
requires data be at most 1-dimensional" +) assert_time_varying_dim_correct = Assert( "The first dimension of a time varying matrix (the time dimension) must be " "equal to the first dimension of the data (the time dimension)." @@ -107,7 +109,11 @@ def initialize_eyes(self, R: TensorVariable, Z: TensorVariable) -> None: 2nd ed, Oxford University Press, 2012. """ - self.n_states, self.n_posdef, self.n_endog = R.shape[-2], R.shape[-1], Z.shape[-2] + self.n_states, self.n_posdef, self.n_endog = ( + R.shape[-2], + R.shape[-1], + Z.shape[-2], + ) self.eye_states = pt.eye(self.n_states) self.eye_posdef = pt.eye(self.n_posdef) self.eye_endog = pt.eye(self.n_endog) @@ -172,7 +178,11 @@ def unpack_args(self, args) -> tuple: y = args.pop(0) # There are always two outputs_info wedged between the seqs and non_seqs - seqs, (a0, P0), non_seqs = args[:n_seq], args[n_seq : n_seq + 2], args[n_seq + 2 :] + seqs, (a0, P0), non_seqs = ( + args[:n_seq], + args[n_seq : n_seq + 2], + args[n_seq + 2 :], + ) return_ordered = [] for name in ["c", "d", "T", "Z", "R", "H", "Q"]: if name in self.seq_names: @@ -251,8 +261,8 @@ def build_graph( data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q) - sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq( - params, PARAM_NAMES + sequences, non_sequences, seq_names, non_seq_names = ( + split_vars_into_seq_and_nonseq(params, PARAM_NAMES) ) self.seq_names = seq_names @@ -271,7 +281,9 @@ def build_graph( strict=False, ) - filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0]) + filter_results = self._postprocess_scan_results( + results, a0, P0, n=data.type.shape[0] + ) if return_updates: return filter_results, updates @@ -449,7 +461,9 @@ def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]: @staticmethod def update( a, P, y, c, d, Z, H, all_nan_flag - ) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable]: + ) -> tuple[ + TensorVariable, TensorVariable, TensorVariable, TensorVariable, TensorVariable + ]: """ Perform the update step of the Kalman filter. 
@@ -569,7 +583,14 @@ def kalman_step(self, *args) -> tuple: y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H) a_filtered, P_filtered, obs_mu, obs_cov, ll = self.update( - y=y_masked, a=a, c=c, d=d, P=P, Z=Z_masked, H=H_masked, all_nan_flag=all_nan_flag + y=y_masked, + a=a, + c=c, + d=d, + P=P, + Z=Z_masked, + H=H_masked, + all_nan_flag=all_nan_flag, ) P_filtered = stabilize(P_filtered, self.cov_jitter) @@ -695,7 +716,8 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): all_nan_flag, 0.0, ( - -0.5 * (n * MVN_CONST + (v.T @ inner_term).ravel()) - pt.log(pt.diag(F_chol)).sum() + -0.5 * (n * MVN_CONST + (v.T @ inner_term).ravel()) + - pt.log(pt.diag(F_chol)).sum() ).ravel()[0], ) @@ -735,7 +757,9 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H) - ll = pt.switch(all_nan_flag, 0.0, -0.5 * (MVN_CONST + pt.log(F) + v**2 / F)).ravel()[0] + ll = pt.switch( + all_nan_flag, 0.0, -0.5 * (MVN_CONST + pt.log(F) + v**2 / F) + ).ravel()[0] return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll @@ -783,8 +807,8 @@ def build_graph( self.initialize_eyes(R, Z) data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q) - sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq( - params, PARAM_NAMES + sequences, non_sequences, seq_names, non_seq_names = ( + split_vars_into_seq_and_nonseq(params, PARAM_NAMES) ) self.seq_names = seq_names self.non_seq_names = non_seq_names @@ -797,7 +821,9 @@ def build_graph( P_steady = solve_discrete_are(T.T, Z.T, matrix_dot(R, Q, R.T), H) F = matrix_dot(Z, P_steady, Z.T) + H - F_inv = pt.linalg.solve(F, pt.eye(F.shape[0]), assume_a="pos", check_finite=False) + F_inv = pt.linalg.solve( + F, pt.eye(F.shape[0]), assume_a="pos", check_finite=False + ) results, updates = pytensor.scan( self.kalman_step, diff --git a/pymc_experimental/statespace/filters/kalman_smoother.py b/pymc_experimental/statespace/filters/kalman_smoother.py index 32581fa4f..2ab07b5a4 100644 --- a/pymc_experimental/statespace/filters/kalman_smoother.py +++ b/pymc_experimental/statespace/filters/kalman_smoother.py @@ -65,7 +65,14 @@ def unpack_args(self, args): return a, P, a_smooth, P_smooth, T, R, Q def build_graph( - self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT + self, + T, + R, + Q, + filtered_states, + filtered_covariances, + mode=None, + cov_jitter=JITTER_DEFAULT, ): self.mode = mode self.cov_jitter = cov_jitter @@ -75,8 +82,8 @@ def build_graph( a_last = pt.specify_shape(filtered_states[-1], (k,)) P_last = pt.specify_shape(filtered_covariances[-1], (k, k)) - sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq( - [T, R, Q], ["T", "R", "Q"] + sequences, non_sequences, seq_names, non_seq_names = ( + split_vars_into_seq_and_nonseq([T, R, Q], ["T", "R", "Q"]) ) self.seq_names = seq_names diff --git a/pymc_experimental/statespace/models/SARIMAX.py b/pymc_experimental/statespace/models/SARIMAX.py index 840eb0acb..d221a0fda 100644 --- a/pymc_experimental/statespace/models/SARIMAX.py +++ b/pymc_experimental/statespace/models/SARIMAX.py @@ -372,7 +372,9 @@ def _stationary_initialization(self, mode=None): Q = self.ssm["state_cov"] c = self.ssm["state_intercept"] - x0 = pt.linalg.solve(pt.identity_like(T) - T, c, assume_a="gen", check_finite=True) + x0 = pt.linalg.solve( + pt.identity_like(T) - T, c, assume_a="gen", check_finite=True + ) method = "direct" if (self.k_states < 5) or 
(mode == "JAX") else "bilinear" P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method=method) @@ -385,7 +387,9 @@ def make_symbolic_graph(self) -> None: # Initial state and covariance can be handled first if we're not doing a stationary initialization if not self.stationary_initialization: - x0 = self.make_and_register_variable("x0", shape=(self.k_states,), dtype=floatX) + x0 = self.make_and_register_variable( + "x0", shape=(self.k_states,), dtype=floatX + ) P0 = self.make_and_register_variable( "P0", shape=(self.k_states, self.k_states), dtype=floatX ) @@ -395,9 +399,9 @@ def make_symbolic_graph(self) -> None: # Design matrix has no RVs k_lags = self.k_states - self._k_diffs - self.ssm["design"] = np.r_[[1] * d, ([0] * (S - 1) + [1]) * D, [1], [0] * (k_lags - 1)][ - None - ] + self.ssm["design"] = np.r_[ + [1] * d, ([0] * (S - 1) + [1]) * D, [1], [0] * (k_lags - 1) + ][None] # Set up the transition and selection matrices, depending on the requested representation if self.state_structure == "fast": @@ -409,13 +413,17 @@ def make_symbolic_graph(self) -> None: ar_param_idx = np.s_[ "transition", self._k_diffs : self._k_diffs + self.p, self._k_diffs ] - ma_param_idx = np.s_["selection", 1 + self._k_diffs : 1 + self._k_diffs + self.q, 0] + ma_param_idx = np.s_[ + "selection", 1 + self._k_diffs : 1 + self._k_diffs + self.q, 0 + ] self.ssm["transition"] = transition self.ssm["selection"] = selection if p > 0: - ar_params = self.make_and_register_variable("ar_params", shape=(p,), dtype=floatX) + ar_params = self.make_and_register_variable( + "ar_params", shape=(p,), dtype=floatX + ) self.ssm[ar_param_idx] = ar_params if P > 0: @@ -432,12 +440,14 @@ def make_symbolic_graph(self) -> None: idx_rows.repeat(p) + np.tile(np.arange(p), P) + 1, self._k_diffs, ] - self.ssm[cross_term_idx] = -pt.repeat(seasonal_ar_params, p) * pt.tile( - ar_params, P - ) + self.ssm[cross_term_idx] = -pt.repeat( + seasonal_ar_params, p + ) * pt.tile(ar_params, P) if q > 0: - ma_params = self.make_and_register_variable("ma_params", shape=(q,), dtype=floatX) + ma_params = self.make_and_register_variable( + "ma_params", shape=(q,), dtype=floatX + ) self.ssm[ma_param_idx] = ma_params if Q > 0: @@ -450,11 +460,13 @@ def make_symbolic_graph(self) -> None: if q > 0: cross_term_idx = np.s_[ - "selection", idx_rows.repeat(q) + np.tile(np.arange(q), Q) + 1, 0 + "selection", + idx_rows.repeat(q) + np.tile(np.arange(q), Q) + 1, + 0, ] - self.ssm[cross_term_idx] = pt.repeat(seasonal_ma_params, q) * pt.tile( - ma_params, Q - ) + self.ssm[cross_term_idx] = pt.repeat( + seasonal_ma_params, q + ) * pt.tile(ma_params, Q) elif self.state_structure == "interpretable": ar_param_idx = np.s_["transition", 0, : max(1, p)] @@ -485,11 +497,13 @@ def make_symbolic_graph(self) -> None: if p > 0: cross_term_idx = np.s_[ - "transition", 0, idx_cols.repeat(p) + np.tile(np.arange(p), P) + 1 + "transition", + 0, + idx_cols.repeat(p) + np.tile(np.arange(p), P) + 1, ] - self.ssm[cross_term_idx] = -pt.repeat(seasonal_ar_params, p) * pt.tile( - ar_params, P - ) + self.ssm[cross_term_idx] = -pt.repeat( + seasonal_ar_params, p + ) * pt.tile(ar_params, P) if self.q > 0: ma_params = self.make_and_register_variable( @@ -507,23 +521,29 @@ def make_symbolic_graph(self) -> None: if q > 0: cross_term_idx = np.s_[ - "transition", 0, idx_cols.repeat(q) + np.tile(np.arange(q), Q) + 1 + "transition", + 0, + idx_cols.repeat(q) + np.tile(np.arange(q), Q) + 1, ] - self.ssm[cross_term_idx] = pt.repeat(seasonal_ma_params, q) * pt.tile( - ma_params, Q - ) + 
self.ssm[cross_term_idx] = pt.repeat( + seasonal_ma_params, q + ) * pt.tile(ma_params, Q) # Set up the state covariance matrix state_cov_idx = ("state_cov",) + np.diag_indices(self.k_posdef) state_cov = self.make_and_register_variable( - "sigma_state", shape=() if self.k_posdef == 1 else (self.k_posdef,), dtype=floatX + "sigma_state", + shape=() if self.k_posdef == 1 else (self.k_posdef,), + dtype=floatX, ) self.ssm[state_cov_idx] = state_cov**2 if self.measurement_error: obs_cov_idx = ("obs_cov",) + np.diag_indices(self.k_endog) obs_cov = self.make_and_register_variable( - "sigma_obs", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX + "sigma_obs", + shape=() if self.k_endog == 1 else (self.k_endog,), + dtype=floatX, ) self.ssm[obs_cov_idx] = obs_cov**2 diff --git a/pymc_experimental/statespace/models/VARMAX.py b/pymc_experimental/statespace/models/VARMAX.py index 0942d6db3..56ea08761 100644 --- a/pymc_experimental/statespace/models/VARMAX.py +++ b/pymc_experimental/statespace/models/VARMAX.py @@ -156,7 +156,9 @@ def __init__( endog_names = [f"state.{i + 1}" for i in range(k_endog)] if (endog_names is not None) and (k_endog is not None): if len(endog_names) != k_endog: - raise ValueError("Length of provided endog_names does not match provided k_endog") + raise ValueError( + "Length of provided endog_names does not match provided k_endog" + ) self.endog_names = list(endog_names) self.p, self.q = order @@ -240,7 +242,9 @@ def state_names(self): f"L{i + 1}.{state}" for i in range(self.p - 1) for state in self.endog_names ] state_names += [ - f"L{i + 1}.{state}_innov" for i in range(self.q) for state in self.endog_names + f"L{i + 1}.{state}_innov" + for i in range(self.q) + for state in self.endog_names ] return state_names @@ -297,7 +301,9 @@ def make_symbolic_graph(self) -> None: # Initialize the matrices if not self.stationary_initialization: # initial states - x0 = self.make_and_register_variable("x0", shape=(self.k_states,), dtype=floatX) + x0 = self.make_and_register_variable( + "x0", shape=(self.k_states,), dtype=floatX + ) self.ssm["initial_state", :] = x0 # initial covariance @@ -330,7 +336,11 @@ def make_symbolic_graph(self) -> None: self.ssm[("transition",) + idx] = np.eye(self.k_endog * (self.q - 1)) if self.p > 0: - ar_param_idx = ("transition", slice(0, self.k_endog), slice(0, self.k_endog * self.p)) + ar_param_idx = ( + "transition", + slice(0, self.k_endog), + slice(0, self.k_endog * self.p), + ) # Register the AR parameter matrix as a (k, p, k), then reshape it and allocate it in the transition matrix # This way the user can use 3 dimensions in the prior (clearer?) 
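The two comment lines above describe flattening a (k, p, k) AR coefficient cube into the two-dimensional block that ar_param_idx addresses in the transition matrix. A small NumPy sketch of that reshape follows; the sizes and the row-major flattening of the last two axes are illustrative assumptions, not code taken from this changeset:

    import numpy as np

    k_endog, p = 2, 3  # hypothetical: 2 endogenous series, 3 AR lags
    # an (equation, lag, regressor) cube is what the 3-dimensional prior sees
    ar_cube = np.arange(k_endog * p * k_endog, dtype=float).reshape(k_endog, p, k_endog)
    # row-major flattening lays each row out as [A_1 | A_2 | A_3], the top
    # block expected by a VAR companion-form transition matrix
    ar_block = ar_cube.reshape(k_endog, k_endog * p)
    assert ar_block.shape == (2, 6)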
@@ -361,7 +371,9 @@ def make_symbolic_graph(self) -> None: self.ssm[ma_param_idx] = ma_params end = -self.k_endog * (self.q - 1) if self.q > 1 else None - self.ssm["selection", slice(self.k_endog * -self.q, end), :] = np.eye(self.k_endog) + self.ssm["selection", slice(self.k_endog * -self.q, end), :] = np.eye( + self.k_endog + ) if self.measurement_error: obs_cov_idx = ("obs_cov",) + np.diag_indices(self.k_endog) @@ -382,7 +394,9 @@ def make_symbolic_graph(self) -> None: Q = self.ssm["state_cov"] c = self.ssm["state_intercept"] - x0 = pt.linalg.solve(pt.eye(T.shape[0]) - T, c, assume_a="gen", check_finite=False) + x0 = pt.linalg.solve( + pt.eye(T.shape[0]) - T, c, assume_a="gen", check_finite=False + ) P0 = solve_discrete_lyapunov( T, pt.linalg.matrix_dot(R, Q, R.T), diff --git a/pymc_experimental/statespace/models/structural.py b/pymc_experimental/statespace/models/structural.py index d5985adcb..2e6b1a7c8 100644 --- a/pymc_experimental/statespace/models/structural.py +++ b/pymc_experimental/statespace/models/structural.py @@ -133,7 +133,9 @@ def __init__( self.ssm["initial_state_cov"] = P0 @staticmethod - def _add_inital_state_cov_to_properties(param_names, param_dims, param_info, k_states): + def _add_inital_state_cov_to_properties( + param_names, param_dims, param_info, k_states + ): param_names += ["P0"] param_dims["P0"] = (ALL_STATE_DIM, ALL_STATE_AUX_DIM) param_info["P0"] = { @@ -377,7 +379,9 @@ def __init__( self.param_counts = {} if representation is None: - self.ssm = PytensorRepresentation(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef) + self.ssm = PytensorRepresentation( + k_endog=k_endog, k_states=k_states, k_posdef=k_posdef + ) else: self.ssm = representation @@ -515,13 +519,17 @@ def make_slice(name, x, o_x): for name, x, o_x in zip(LONG_MATRIX_NAMES, self_matrices, other_matrices) ) - initial_state = pt.concatenate(conform_time_varying_and_time_invariant_matrices(x0, o_x0)) + initial_state = pt.concatenate( + conform_time_varying_and_time_invariant_matrices(x0, o_x0) + ) initial_state.name = x0.name initial_state_cov = pt.linalg.block_diag(P0, o_P0) initial_state_cov.name = P0.name - state_intercept = pt.concatenate(conform_time_varying_and_time_invariant_matrices(c, o_c)) + state_intercept = pt.concatenate( + conform_time_varying_and_time_invariant_matrices(c, o_c) + ) state_intercept.name = c.name obs_intercept = d + o_d @@ -530,7 +538,9 @@ def make_slice(name, x, o_x): transition = pt.linalg.block_diag(T, o_T) transition.name = T.name - design = pt.concatenate(conform_time_varying_and_time_invariant_matrices(Z, o_Z), axis=-1) + design = pt.concatenate( + conform_time_varying_and_time_invariant_matrices(Z, o_Z), axis=-1 + ) design.name = Z.name selection = pt.linalg.block_diag(R, o_R) @@ -833,10 +843,14 @@ def __init__( def populate_component_properties(self): name_slice = POSITION_DERIVATIVE_NAMES[: self.k_states] self.param_names = ["initial_trend"] - self.state_names = [name for name, mask in zip(name_slice, self._order_mask) if mask] + self.state_names = [ + name for name, mask in zip(name_slice, self._order_mask) if mask + ] self.param_dims = {"initial_trend": ("trend_state",)} self.coords = {"trend_state": self.state_names} - self.param_info = {"initial_trend": {"shape": (self.k_states,), "constraints": None}} + self.param_info = { + "initial_trend": {"shape": (self.k_states,), "constraints": None} + } if self.k_posdef > 0: self.param_names += ["sigma_trend"] @@ -845,13 +859,18 @@ def populate_component_properties(self): ] self.param_dims["sigma_trend"] = 
("trend_shock",) self.coords["trend_shock"] = self.shock_names - self.param_info["sigma_trend"] = {"shape": (self.k_posdef,), "constraints": "Positive"} + self.param_info["sigma_trend"] = { + "shape": (self.k_posdef,), + "constraints": "Positive", + } for name in self.param_names: self.param_info[name]["dims"] = self.param_dims[name] def make_symbolic_graph(self) -> None: - initial_trend = self.make_and_register_variable("initial_trend", shape=(self.k_states,)) + initial_trend = self.make_and_register_variable( + "initial_trend", shape=(self.k_states,) + ) self.ssm["initial_state", :] = initial_trend triu_idx = np.triu_indices(self.k_states) self.ssm[np.s_["transition", triu_idx[0], triu_idx[1]]] = 1 @@ -863,7 +882,9 @@ def make_symbolic_graph(self) -> None: self.ssm["design", 0, :] = np.array([1.0] + [0.0] * (self.k_states - 1)) if self.k_posdef > 0: - sigma_trend = self.make_and_register_variable("sigma_trend", shape=(self.k_posdef,)) + sigma_trend = self.make_and_register_variable( + "sigma_trend", shape=(self.k_posdef,) + ) diag_idx = np.diag_indices(self.k_posdef) idx = np.s_["state_cov", diag_idx[0], diag_idx[1]] self.ssm[idx] = sigma_trend**2 @@ -914,7 +935,12 @@ def __init__(self, name: str = "MeasurementError"): k_posdef = 0 super().__init__( - name, k_endog, k_states, k_posdef, measurement_error=True, combine_hidden_states=False + name, + k_endog, + k_states, + k_posdef, + measurement_error=True, + combine_hidden_states=False, ) def populate_component_properties(self): @@ -930,7 +956,9 @@ def populate_component_properties(self): def make_symbolic_graph(self) -> None: sigma_shape = () - error_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=sigma_shape) + error_sigma = self.make_and_register_variable( + f"sigma_{self.name}", shape=sigma_shape + ) diag_idx = np.diag_indices(self.k_endog) idx = np.s_["obs_cov", diag_idx[0], diag_idx[1]] self.ssm[idx] = error_sigma**2 @@ -1036,7 +1064,11 @@ def make_symbolic_graph(self) -> None: self.ssm["selection", 0, 0] = 1 self.ssm["design", 0, 0] = 1 - ar_idx = ("transition", np.zeros(k_nonzero, dtype="int"), np.nonzero(self.order)[0]) + ar_idx = ( + "transition", + np.zeros(k_nonzero, dtype="int"), + np.nonzero(self.order)[0], + ) self.ssm[ar_idx] = ar_params cov_idx = ("state_cov", *np.diag_indices(1)) @@ -1226,7 +1258,9 @@ def make_symbolic_graph(self) -> None: if self.innovations: self.ssm["selection", 0, 0] = 1 - season_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=()) + season_sigma = self.make_and_register_variable( + f"sigma_{self.name}", shape=() + ) cov_idx = ("state_cov", *np.diag_indices(1)) self.ssm[cov_idx] = season_sigma**2 @@ -1315,22 +1349,31 @@ def __init__(self, season_length, n=None, name=None, innovations=True): def make_symbolic_graph(self) -> None: self.ssm["design", 0, slice(0, self.k_states, 2)] = 1 - init_state = self.make_and_register_variable(f"{self.name}", shape=(self.n_coefs,)) + init_state = self.make_and_register_variable( + f"{self.name}", shape=(self.n_coefs,) + ) init_state_idx = np.arange(self.n_coefs, dtype=int) self.ssm["initial_state", init_state_idx] = init_state - T_mats = [_frequency_transition_block(self.season_length, j + 1) for j in range(self.n)] + T_mats = [ + _frequency_transition_block(self.season_length, j + 1) + for j in range(self.n) + ] T = pt.linalg.block_diag(*T_mats) self.ssm["transition", :, :] = T if self.innovations: - sigma_season = self.make_and_register_variable(f"sigma_{self.name}", shape=()) + sigma_season = self.make_and_register_variable( 
+ f"sigma_{self.name}", shape=() + ) self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_season**2 self.ssm["selection", :, :] = np.eye(self.k_states) def populate_component_properties(self): - self.state_names = [f"{self.name}_{f}_{i}" for i in range(self.n) for f in ["Cos", "Sin"]] + self.state_names = [ + f"{self.name}_{f}_{i}" for i in range(self.n) for f in ["Cos", "Sin"] + ] self.param_names = [f"{self.name}"] self.param_dims = {self.name: (f"{self.name}_state",)} @@ -1345,7 +1388,9 @@ def populate_component_properties(self): init_state_idx = np.arange(self.k_states, dtype=int) if self.last_state_not_identified: init_state_idx = init_state_idx[:-1] - self.coords = {f"{self.name}_state": [self.state_names[i] for i in init_state_idx]} + self.coords = { + f"{self.name}_state": [self.state_names[i] for i in init_state_idx] + } if self.innovations: self.shock_names = self.state_names.copy() @@ -1451,9 +1496,13 @@ def __init__( innovations: bool = True, ): if cycle_length is None and not estimate_cycle_length: - raise ValueError("Must specify cycle_length if estimate_cycle_length is False") + raise ValueError( + "Must specify cycle_length if estimate_cycle_length is False" + ) if cycle_length is not None and estimate_cycle_length: - raise ValueError("Cannot specify cycle_length if estimate_cycle_length is True") + raise ValueError( + "Cannot specify cycle_length if estimate_cycle_length is True" + ) if name is None: cycle = int(cycle_length) if cycle_length is not None else "Estimate" name = f"Cycle[s={cycle}, dampen={dampen}, innovations={innovations}]" @@ -1487,7 +1536,9 @@ def make_symbolic_graph(self) -> None: self.param_dims = {self.name: (f"{self.name}_state",)} self.coords = {f"{self.name}_state": self.state_names} - init_state = self.make_and_register_variable(f"{self.name}", shape=(self.k_states,)) + init_state = self.make_and_register_variable( + f"{self.name}", shape=(self.k_states,) + ) self.ssm["initial_state", :] = init_state @@ -1497,7 +1548,9 @@ def make_symbolic_graph(self) -> None: lamb = self.cycle_length if self.dampen: - rho = self.make_and_register_variable(f"{self.name}_dampening_factor", shape=()) + rho = self.make_and_register_variable( + f"{self.name}_dampening_factor", shape=() + ) else: rho = 1 @@ -1505,7 +1558,9 @@ def make_symbolic_graph(self) -> None: self.ssm["transition", :, :] = T if self.innovations: - sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=()) + sigma_cycle = self.make_and_register_variable( + f"sigma_{self.name}", shape=() + ) self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle**2 def populate_component_properties(self): @@ -1574,12 +1629,16 @@ def __init__( ) @staticmethod - def _get_state_names(k_exog: Optional[int], state_names: Optional[list[str]], name: str): + def _get_state_names( + k_exog: Optional[int], state_names: Optional[list[str]], name: str + ): if k_exog is None and state_names is None: raise ValueError("Must specify at least one of k_exog or state_names") if state_names is not None and k_exog is not None: if len(state_names) != k_exog: - raise ValueError(f"Expected {k_exog} state names, found {len(state_names)}") + raise ValueError( + f"Expected {k_exog} state names, found {len(state_names)}" + ) elif k_exog is None: k_exog = len(state_names) else: @@ -1587,14 +1646,18 @@ def _get_state_names(k_exog: Optional[int], state_names: Optional[list[str]], na return k_exog, state_names - def _handle_input_data(self, k_exog: int, state_names: Optional[list[str]], name) -> int: + def 
_handle_input_data( + self, k_exog: int, state_names: Optional[list[str]], name + ) -> int: k_exog, state_names = self._get_state_names(k_exog, state_names, name) self.state_names = state_names return k_exog def make_symbolic_graph(self) -> None: - betas = self.make_and_register_variable(f"beta_{self.name}", shape=(self.k_states,)) + betas = self.make_and_register_variable( + f"beta_{self.name}", shape=(self.k_states,) + ) regression_data = self.make_and_register_data( f"data_{self.name}", shape=(None, self.k_states) ) @@ -1634,7 +1697,7 @@ def populate_component_properties(self) -> None: "dims": (TIME_DIM, "exog_state"), }, } - self.coords = {f"exog_state": self.state_names} + self.coords = {"exog_state": self.state_names} if self.innovations: self.param_names += [f"sigma_beta_{self.name}"] diff --git a/pymc_experimental/statespace/models/utilities.py b/pymc_experimental/statespace/models/utilities.py index 40f38093e..8bffcd38a 100644 --- a/pymc_experimental/statespace/models/utilities.py +++ b/pymc_experimental/statespace/models/utilities.py @@ -58,7 +58,9 @@ def cleanup_states(states: list[str]) -> list[str]: return out -def make_harvey_state_names(p: int, d: int, q: int, P: int, D: int, Q: int, S: int) -> list[str]: +def make_harvey_state_names( + p: int, d: int, q: int, P: int, D: int, Q: int, S: int +) -> list[str]: """ Generate informative names for the SARIMA states in the Harvey representation diff --git a/pymc_experimental/statespace/utils/constants.py b/pymc_experimental/statespace/utils/constants.py index 3482b0748..5ca56fb84 100644 --- a/pymc_experimental/statespace/utils/constants.py +++ b/pymc_experimental/statespace/utils/constants.py @@ -69,5 +69,13 @@ "predicted_observed_covariance": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM), } -POSITION_DERIVATIVE_NAMES = ["level", "trend", "acceleration", "jerk", "snap", "crackle", "pop"] +POSITION_DERIVATIVE_NAMES = [ + "level", + "trend", + "acceleration", + "jerk", + "snap", + "crackle", + "pop", +] SARIMAX_STATE_STRUCTURES = ["fast", "interpretable"] diff --git a/pymc_experimental/statespace/utils/data_tools.py b/pymc_experimental/statespace/utils/data_tools.py index 67ab3d89a..80949f805 100644 --- a/pymc_experimental/statespace/utils/data_tools.py +++ b/pymc_experimental/statespace/utils/data_tools.py @@ -34,7 +34,9 @@ def get_data_dims(data): return data_dims -def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=False, col_names=None): +def _validate_data_shape( + data_shape, n_obs, obs_coords=None, check_col_names=False, col_names=None +): if col_names is None: col_names = [] @@ -150,7 +152,7 @@ def mask_missing_values_in_data(values, missing_fill_value=None): ) impute_message = ( - f"Provided data contains missing values and" + "Provided data contains missing values and" " will be automatically imputed as hidden states" " during Kalman filtering." 
) @@ -170,7 +172,9 @@ def register_data_with_pymc( elif isinstance(data, (pd.DataFrame, pd.Series)): values, index = preprocess_pandas_data(data, n_obs, obs_coords) else: - raise ValueError("Data should be one of pytensor tensor, numpy array, or pandas dataframe") + raise ValueError( + "Data should be one of pytensor tensor, numpy array, or pandas dataframe" + ) data, nan_mask = mask_missing_values_in_data(values, missing_fill_value) diff --git a/pymc_experimental/utils/linear_cg.py b/pymc_experimental/utils/linear_cg.py index 49457cad5..c5ba3a63b 100644 --- a/pymc_experimental/utils/linear_cg.py +++ b/pymc_experimental/utils/linear_cg.py @@ -26,7 +26,14 @@ def masked_fill(vector, mask, fill_value): def linear_cg_updates( - result, alpha, residual_inner_prod, eps, beta, residual, precond_residual, curr_conjugate_vec + result, + alpha, + residual_inner_prod, + eps, + beta, + residual, + precond_residual, + curr_conjugate_vec, ): # Everything inside _jit_linear_cg_updates result = result + alpha * curr_conjugate_vec @@ -65,13 +72,16 @@ def linear_cg( initial_guess=None, preconditioner=None, terminate_cg_by_size=False, - use_eval_tolerange=False, + use_eval_tolerance=False, ): if initial_guess is None: initial_guess = np.zeros_like(rhs) if preconditioner is None: - preconditioner = lambda x: x + + def preconditioner(x): + return x + precond = False else: precond = True @@ -112,7 +122,9 @@ def linear_cg( result = np.copy(initial_guess) if not np.allclose(residual, residual): - raise RuntimeError("NaNs encountered when trying to perform matrix-vector multiplication") + raise RuntimeError( + "NaNs encountered when trying to perform matrix-vector multiplication" + ) # sometimes we are lucky and preconditioner solves the system right away # check for convergence @@ -128,7 +140,7 @@ def linear_cg( residual_inner_prod = residual.T @ precond_residual # define storage matrices - mul_storage = np.zeros_like(residual) + np.zeros_like(residual) alpha = np.zeros((*batch_shape, 1, rhs.shape[-1])) beta = np.zeros_like(alpha) is_zero = np.zeros((*batch_shape, 1, rhs.shape[-1])) @@ -239,14 +251,20 @@ def linear_cg( beta_tridiag = np.copy(beta) alpha_tridiag_is_zero = alpha_tridiag == 0 - alpha_tridiag = masked_fill(alpha_tridiag, mask=alpha_tridiag_is_zero, fill_value=1) + alpha_tridiag = masked_fill( + alpha_tridiag, mask=alpha_tridiag_is_zero, fill_value=1 + ) alpha_reciprocal = 1 / alpha_tridiag - alpha_tridiag = masked_fill(alpha_tridiag, mask=alpha_tridiag_is_zero, fill_value=0) + alpha_tridiag = masked_fill( + alpha_tridiag, mask=alpha_tridiag_is_zero, fill_value=0 + ) if k == 0: t_mat[k, k] = alpha_reciprocal else: - t_mat[k, k] += np.squeeze(alpha_reciprocal + prev_beta * prev_alpha_reciprocal) + t_mat[k, k] += np.squeeze( + alpha_reciprocal + prev_beta * prev_alpha_reciprocal + ) t_mat[k, k - 1] = np.sqrt(prev_beta) * prev_alpha_reciprocal t_mat[k - 1, k] = np.copy(t_mat[k, k - 1]) diff --git a/pymc_experimental/utils/pivoted_cholesky.py b/pymc_experimental/utils/pivoted_cholesky.py index 69ea9cd7b..1d436fea7 100644 --- a/pymc_experimental/utils/pivoted_cholesky.py +++ b/pymc_experimental/utils/pivoted_cholesky.py @@ -6,7 +6,9 @@ import numpy as np -pp = lambda x: np.array2string(x, precision=4, floatmode="fixed") + +def pp(x): + return np.array2string(x, precision=4, floatmode="fixed") def pivoted_cholesky(mat: np.matrix, error_tol=1e-6, max_iter=np.inf): diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 30d4e9507..00f5f6aa9 100644 --- 
a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -44,7 +44,9 @@ class FlatInfo(TypedDict): info: List[VarInfo] -def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, Transform, str, Tuple]] = None): +def _arg_to_param_cfg( + key, value: Optional[Union[ParamCfg, Transform, str, Tuple]] = None +): if value is None: cfg = ParamCfg(name=key, transform=None, dims=None) elif isinstance(value, Tuple): @@ -133,7 +135,7 @@ def prior_from_idata( name="trace_prior_", *, var_names: Sequence[str] = (), - **kwargs: Union[ParamCfg, Transform, str, Tuple] + **kwargs: Union[ParamCfg, Transform, str, Tuple], ) -> Dict[str, pt.TensorVariable]: """ Create a prior from posterior using MvNormal approximation. diff --git a/pymc_experimental/utils/spline.py b/pymc_experimental/utils/spline.py index 2e4db5e75..c231ea645 100644 --- a/pymc_experimental/utils/spline.py +++ b/pymc_experimental/utils/spline.py @@ -41,11 +41,13 @@ def __init__(self, sparse=True) -> None: def make_node(self, *inputs) -> Apply: eval_points, k, d = map(pt.as_tensor, inputs) - if not (eval_points.ndim == 1 and np.issubdtype(eval_points.dtype, np.floating)): + if not ( + eval_points.ndim == 1 and np.issubdtype(eval_points.dtype, np.floating) + ): raise TypeError("eval_points should be a vector of floats") - if not k.type in pt.int_types: + if k.type not in pt.int_types: raise TypeError("k should be integer") - if not d.type in pt.int_types: + if d.type not in pt.int_types: raise TypeError("degree should be integer") if self.sparse: out_type = ps.SparseTensorType("csr", eval_points.dtype)() diff --git a/pymc_experimental/version.py b/pymc_experimental/version.py index cc0aacca5..6ecd179e4 100644 --- a/pymc_experimental/version.py +++ b/pymc_experimental/version.py @@ -2,7 +2,9 @@ def get_version(): - version_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "version.txt") + version_file = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "version.txt" + ) with open(version_file) as f: version = f.read().strip() return version diff --git a/setup.py b/setup.py index 92c0ea397..273257c98 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ # limitations under the License. import itertools +import os from codecs import open from os.path import dirname, join, realpath @@ -60,13 +61,11 @@ dask_histogram=["dask[complete]", "xhistogram"], histogram=["xhistogram"], ) -extras_require["complete"] = sorted(set(itertools.chain.from_iterable(extras_require.values()))) +extras_require["complete"] = sorted( + set(itertools.chain.from_iterable(extras_require.values())) +) extras_require["dev"] = dev_install_reqs -import os - -from setuptools import find_packages, setup - def read_version(): here = os.path.abspath(os.path.dirname(__file__)) diff --git a/setupegg.py b/setupegg.py index 888a65c9b..168c05b99 100755 --- a/setupegg.py +++ b/setupegg.py @@ -17,7 +17,5 @@ A setup.py script to use setuptools, which gives egg goodness, etc. 
""" -from setuptools import setup - with open("setup.py") as s: exec(s.read()) diff --git a/tests/distributions/__init__.py b/tests/distributions/__init__.py index fa2a64480..d4b13bfcb 100644 --- a/tests/distributions/__init__.py +++ b/tests/distributions/__init__.py @@ -15,3 +15,5 @@ from pymc_experimental.distributions import histogram_utils from pymc_experimental.distributions.histogram_utils import histogram_approximation + +__all__ = ["histogram_utils", "histogram_approximation"] diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index ced0745b3..97832a35c 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -85,7 +85,13 @@ def ref_logcdf(value, mu, sigma, xi): "mu, sigma, xi, size, expected", [ (0, 1, 0, None, 0), - (1, np.arange(1, 4), 0.1, None, 1 + np.arange(1, 4) * (1.1**-0.1 - 1) / 0.1), + ( + 1, + np.arange(1, 4), + 0.1, + None, + 1 + np.arange(1, 4) * (1.1**-0.1 - 1) / 0.1, + ), (np.arange(5), 1, 0.1, None, np.arange(5) + (1.1**-0.1 - 1) / 0.1), ( 0, @@ -105,7 +111,10 @@ def ref_logcdf(value, mu, sigma, xi): (3, 6), np.arange(6) + np.arange(1, 7) - * ((1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) - 1) + * ( + (1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) + - 1 + ) / np.linspace(-0.2, 0.2, 6), ), ), diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py index 60885908f..716c703f1 100644 --- a/tests/distributions/test_discrete.py +++ b/tests/distributions/test_discrete.py @@ -104,14 +104,18 @@ def test_logp_lam_expected_moments(self): mu = 30 lam = np.array([-0.9, -0.7, -0.2, 0, 0.2, 0.7, 0.9]) with pm.Model(): - x = GeneralizedPoisson("x", mu=mu, lam=lam) + GeneralizedPoisson("x", mu=mu, lam=lam) trace = pm.sample(chains=1, draws=10_000, random_seed=96).posterior expected_mean = mu / (1 - lam) - np.testing.assert_allclose(trace["x"].mean(("chain", "draw")), expected_mean, rtol=1e-1) + np.testing.assert_allclose( + trace["x"].mean(("chain", "draw")), expected_mean, rtol=1e-1 + ) expected_std = np.sqrt(mu / (1 - lam) ** 3) - np.testing.assert_allclose(trace["x"].std(("chain", "draw")), expected_std, rtol=1e-1) + np.testing.assert_allclose( + trace["x"].std(("chain", "draw")), expected_std, rtol=1e-1 + ) @pytest.mark.parametrize( "mu, lam, size, expected", diff --git a/tests/distributions/test_discrete_markov_chain.py b/tests/distributions/test_discrete_markov_chain.py index 0e55319de..a2585fd8b 100644 --- a/tests/distributions/test_discrete_markov_chain.py +++ b/tests/distributions/test_discrete_markov_chain.py @@ -23,10 +23,15 @@ def transition_probability_tests(steps, n_states, n_lags, n_draws, atol): # Test x0 is uniform over n_states for i in range(n_lags): assert np.allclose( - np.histogram(draws[:, ..., i], bins=n_states)[0] / n_draws, 1 / n_states, atol=atol + np.histogram(draws[:, ..., i], bins=n_states)[0] / n_draws, + 1 / n_states, + atol=atol, ) - n_grams = [[tuple(row[i : i + n_lags + 1]) for i in range(len(row) - n_lags)] for row in draws] + n_grams = [ + [tuple(row[i : i + n_lags + 1]) for i in range(len(row) - n_lags)] + for row in draws + ] freq_table = np.zeros((n_states,) * (n_lags + 1)) for row in n_grams: @@ -60,16 +65,20 @@ def test_high_dimensional_P(self): x0 = pm.Categorical.dist(p=np.ones(3) / 3) chain = DiscreteMarkovChain.dist(P=P, steps=10, init_dist=x0, n_lags=n_lags) draws = pm.draw(chain, 10) - logp = pm.logp(chain, draws) + pm.logp(chain, draws) def test_default_init_dist_warns_user(self): - P = 
pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]])) + P = pt.as_tensor_variable( + np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]) + ) with pytest.warns(UserWarning): DiscreteMarkovChain.dist(P=P, steps=3) def test_logp_shape(self): - P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]])) + P = pt.as_tensor_variable( + np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]) + ) x0 = pm.Categorical.dist(p=np.ones(3) / 3) # Test with steps @@ -87,7 +96,9 @@ def test_logp_shape(self): assert logp.shape == (5,) def test_logp_with_default_init_dist(self): - P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]])) + P = pt.as_tensor_variable( + np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]) + ) x0 = pm.Categorical.dist(p=np.ones(3) / 3) value = np.array([0, 1, 2]) @@ -105,7 +116,9 @@ def test_logp_with_default_init_dist(self): np.testing.assert_allclose(model_logp_eval, logp_expected, rtol=1e-6) def test_logp_with_user_defined_init_dist(self): - P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]])) + P = pt.as_tensor_variable( + np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]) + ) x0 = pm.Categorical.dist(p=[0.2, 0.6, 0.2]) chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=3) @@ -161,7 +174,9 @@ def test_dims_when_steps_are_defined(self): P = pt.full((3, 3), 1 / 3) x0 = pm.Categorical.dist(p=np.ones(3) / 3) - chain = DiscreteMarkovChain("chain", P=P, steps=3, init_dist=x0, dims=["steps"]) + chain = DiscreteMarkovChain( + "chain", P=P, steps=3, init_dist=x0, dims=["steps"] + ) assert chain.eval().shape == (4,) @@ -196,16 +211,24 @@ def test_multiple_lags_with_data(self): x0 = pm.Categorical.dist(p=[0.1, 0.1, 0.8], size=2) data = pm.draw(x0, 100) - chain = DiscreteMarkovChain("chain", P=P, init_dist=x0, n_lags=2, observed=data) + chain = DiscreteMarkovChain( + "chain", P=P, init_dist=x0, n_lags=2, observed=data + ) assert chain.eval().shape == (100, 2) def test_random_draws(self): - transition_probability_tests(steps=3, n_states=2, n_lags=1, n_draws=2500, atol=0.05) - transition_probability_tests(steps=3, n_states=2, n_lags=3, n_draws=7500, atol=0.05) + transition_probability_tests( + steps=3, n_states=2, n_lags=1, n_draws=2500, atol=0.05 + ) + transition_probability_tests( + steps=3, n_states=2, n_lags=3, n_draws=7500, atol=0.05 + ) def test_change_size_univariate(self): - P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]])) + P = pt.as_tensor_variable( + np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]) + ) x0 = pm.Categorical.dist(p=np.ones(3) / 3) chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, shape=(100, 5)) diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index f0ecfa98e..a48b1293f 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -9,7 +9,9 @@ class TestR2D2M2CP: @pytest.fixture(autouse=True) def fast_compile(self): - with pytensor.config.change_flags(mode="FAST_COMPILE", exception_verbosity="high"): + with pytensor.config.change_flags( + mode="FAST_COMPILE", exception_verbosity="high" + ): yield @pytest.fixture(autouse=True) @@ -123,7 +125,9 @@ def test_init_r2( assert beta.eval().shape == input_std.shape # r2 rv is only created if r2 std is not None assert "beta" in model.named_vars - assert ("beta::r2" in model.named_vars) == 
(r2_std is not None), set(model.named_vars) + assert ("beta::r2" in model.named_vars) == (r2_std is not None), set( + model.named_vars + ) # phi is only created if variable importance is not None and there is more than one var assert np.isfinite(model.compile_logp()(model.initial_point())) @@ -221,11 +225,18 @@ def test_failing_variance_explained(self, dims, input_shape, output_std, input_s ) else: pmx.distributions.R2D2M2CP( - "beta", output_std, input_std, dims=dims, r2=0.8, variance_explained=abs(input_std) + "beta", + output_std, + input_std, + dims=dims, + r2=0.8, + variance_explained=abs(input_std), ) def test_failing_mutual_exclusive(self, model: pm.Model): - with pytest.raises(TypeError, match="variable importance with variance explained"): + with pytest.raises( + TypeError, match="variable importance with variance explained" + ): with model: model.add_coord("a", range(2)) pmx.distributions.R2D2M2CP( @@ -295,10 +306,18 @@ def test_limit_case_creates_masked_vars(self, model: pm.Model, centered: bool): def test_zero_length_rvs_not_created(self, model: pm.Model): model.add_coord("a", range(2)) # deterministic case which should not have any new variables - b = pmx.distributions.R2D2M2CP("b1", 1, [1, 1], r2=0.5, positive_probs=[1, 1], dims="a") + pmx.distributions.R2D2M2CP( + "b1", 1, [1, 1], r2=0.5, positive_probs=[1, 1], dims="a" + ) assert not model.free_RVs, model.free_RVs - b = pmx.distributions.R2D2M2CP( - "b2", 1, [1, 1], r2=0.5, positive_probs=[1, 1], positive_probs_std=[0, 0], dims="a" + pmx.distributions.R2D2M2CP( + "b2", + 1, + [1, 1], + r2=0.5, + positive_probs=[1, 1], + positive_probs_std=[0, 0], + dims="a", ) assert not model.free_RVs, model.free_RVs diff --git a/tests/model/test_marginal_model.py b/tests/model/test_marginal_model.py index fd1ce259c..12892569f 100644 --- a/tests/model/test_marginal_model.py +++ b/tests/model/test_marginal_model.py @@ -41,12 +41,14 @@ def disaster_model(): years = np.arange(1851, 1962) with MarginalModel() as disaster_model: - switchpoint = pm.DiscreteUniform("switchpoint", lower=years.min(), upper=years.max()) + switchpoint = pm.DiscreteUniform( + "switchpoint", lower=years.min(), upper=years.max() + ) early_rate = pm.Exponential("early_rate", 1.0, initval=3) late_rate = pm.Exponential("late_rate", 1.0, initval=1) rate = pm.math.switch(switchpoint >= years, early_rate, late_rate) with pytest.warns(Warning): - disasters = pm.Poisson("disasters", rate, observed=disaster_data) + pm.Poisson("disasters", rate, observed=disaster_data) return disaster_model, years @@ -98,7 +100,7 @@ def test_marginalized_basic(): ), ) y = pm.Normal("y", mu=mu, sigma=sigma) - z = pm.Normal("z", y, observed=data) + pm.Normal("z", y, observed=data) m.marginalize([idx]) assert idx not in m.free_RVs @@ -108,7 +110,7 @@ def test_marginalized_basic(): with pm.Model() as m_ref: sigma = pm.HalfNormal("sigma") y = pm.NormalMixture("y", w=[0.1, 0.3, 0.6], mu=[-1, 0, 1], sigma=sigma) - z = pm.Normal("z", y, observed=data) + pm.Normal("z", y, observed=data) test_point = m_ref.initial_point() ref_logp = m_ref.compile_logp()(test_point) @@ -131,9 +133,9 @@ def test_multiple_independent_marginalized_rvs(): with MarginalModel() as m: sigma = pm.HalfNormal("sigma") idx1 = pm.Bernoulli("idx1", p=0.75) - x = pm.Normal("x", mu=idx1, sigma=sigma) + pm.Normal("x", mu=idx1, sigma=sigma) idx2 = pm.Bernoulli("idx2", p=0.75, shape=(5,)) - y = pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,)) + pm.Normal("y", mu=(idx2 * 2 - 1), sigma=sigma, shape=(5,)) m.marginalize([idx1, idx2]) 
m["x"].owner is not m["y"].owner @@ -142,13 +144,15 @@ def test_multiple_independent_marginalized_rvs(): with pm.Model() as m_ref: sigma = pm.HalfNormal("sigma") - x = pm.NormalMixture("x", w=[0.25, 0.75], mu=[0, 1], sigma=sigma) - y = pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=sigma, shape=(5,)) + pm.NormalMixture("x", w=[0.25, 0.75], mu=[0, 1], sigma=sigma) + pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=sigma, shape=(5,)) # Test logp test_point = m_ref.initial_point() x_logp, y_logp = m.compile_logp(vars=[m["x"], m["y"]], sum=False)(test_point) - x_ref_log, y_ref_logp = m_ref.compile_logp(vars=[m_ref["x"], m_ref["y"]], sum=False)(test_point) + x_ref_log, y_ref_logp = m_ref.compile_logp( + vars=[m_ref["x"], m_ref["y"]], sum=False + )(test_point) np.testing.assert_array_almost_equal(x_logp, x_ref_log.sum()) np.testing.assert_array_almost_equal(y_logp, y_ref_logp) @@ -172,7 +176,9 @@ def test_multiple_dependent_marginalized_rvs(): _m["x"].owner is _m["y"].owner tp = m.initial_point() - ref_logp_x_y = logsumexp([ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)]) + ref_logp_x_y = logsumexp( + [ref_logp_x_y_fn({**tp, **{"idx": idx}}) for idx in (0, 1)] + ) logp_x_y = m.compile_logp([x, y])(tp) np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y) @@ -182,7 +188,7 @@ def test_rv_dependent_multiple_marginalized_rvs(): with MarginalModel() as m: x = pm.Bernoulli("x", 0.1) y = pm.Bernoulli("y", 0.3) - z = pm.DiracDelta("z", c=x + y) + pm.DiracDelta("z", c=x + y) m.marginalize([x, y]) logp = m.compile_logp() @@ -200,9 +206,13 @@ def test_nested_marginalized_rvs(): sigma = pm.HalfNormal("sigma") idx = pm.Bernoulli("idx", p=0.75) - dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma) + dep = pm.Normal( + "dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma + ) - sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95), shape=(5,)) + sub_idx = pm.Bernoulli( + "sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95), shape=(5,) + ) sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma, shape=(5,)) ref_logp_fn = m.compile_logp(vars=[idx, dep, sub_idx, sub_dep]) @@ -239,7 +249,9 @@ def test_marginalized_change_point_model(disaster_model): ref_logp_fn = m.compile_logp( [m["switchpoint"], m["disasters_observed"], m["disasters_unobserved"]] ) - ref_logp = logsumexp([ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years]) + ref_logp = logsumexp( + [ref_logp_fn({**ip, **{"switchpoint": year}}) for year in years] + ) with pytest.warns(UserWarning, match="There are multiple dependent variables"): m.marginalize(m["switchpoint"]) @@ -256,13 +268,17 @@ def test_marginalized_change_point_model_sampling(disaster_model): rng = np.random.default_rng(211) with m: - before_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(sample=("draw", "chain")) + before_marg = pm.sample(chains=2, random_seed=rng).posterior.stack( + sample=("draw", "chain") + ) with pytest.warns(UserWarning, match="There are multiple dependent variables"): m.marginalize([m["switchpoint"]]) with m: - after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack(sample=("draw", "chain")) + after_marg = pm.sample(chains=2, random_seed=rng).posterior.stack( + sample=("draw", "chain") + ) np.testing.assert_allclose( before_marg["early_rate"].mean(), after_marg["early_rate"].mean(), rtol=1e-2 @@ -284,7 +300,7 @@ def test_recover_marginals_basic(): k = pm.Categorical("k", p=p) mu = np.array([-3.0, 0.0, 3.0]) mu_ = 
pt.as_tensor_variable(mu)
-        y = pm.Normal("y", mu=mu_[k], sigma=sigma)
+        pm.Normal("y", mu=mu_[k], sigma=sigma)
 
         m.marginalize([k])
 
@@ -328,7 +344,7 @@ def test_recover_marginals_coords():
     with MarginalModel(coords={"year": [1990, 1991, 1992]}) as m:
         sigma = pm.HalfNormal("sigma")
         idx = pm.Bernoulli("idx", p=0.75, dims="year")
-        x = pm.Normal("x", mu=idx, sigma=sigma, dims="year")
+        pm.Normal("x", mu=idx, sigma=sigma, dims="year")
         m.marginalize([idx])
 
     rng = np.random.default_rng(211)
@@ -340,7 +356,9 @@ def test_recover_marginals_coords():
         return_inferencedata=False,
     )
     idata = InferenceData(
-        posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
+        posterior=dict_to_dataset(
+            {k: np.expand_dims(prior[k], axis=0) for k in prior}
+        )
     )
 
     idata = m.recover_marginals(idata, return_samples=True)
@@ -354,7 +372,7 @@ def test_recover_batched_marginal():
     with MarginalModel() as m:
         sigma = pm.HalfNormal("sigma")
         idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
-        y = pm.Normal("y", mu=idx, sigma=sigma, shape=(3, 2))
+        pm.Normal("y", mu=idx, sigma=sigma, shape=(3, 2))
 
         m.marginalize([idx])
 
@@ -367,7 +385,9 @@ def test_recover_batched_marginal():
         return_inferencedata=False,
    )
    idata = InferenceData(
-        posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
+        posterior=dict_to_dataset(
+            {k: np.expand_dims(prior[k], axis=0) for k in prior}
+        )
    )
 
    idata = m.recover_marginals(idata, return_samples=True)
@@ -385,7 +405,7 @@ def test_nested_recover_marginals():
    with MarginalModel() as m:
        idx = pm.Bernoulli("idx", p=0.75)
        sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
-        sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0)
+        pm.Normal("y", mu=idx + sub_idx, sigma=1.0)
 
        m.marginalize([idx, sub_idx])
 
@@ -411,8 +431,12 @@ def test_nested_recover_marginals():
    assert post.lp_sub_idx.shape == post.sub_idx.shape + (2,)
 
    def true_idx_logp(y):
-        idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1))
-        idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
+        idx_0 = np.log(
+            0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1)
+        )
+        idx_1 = np.log(
+            0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2)
+        )
        return log_softmax(np.stack([idx_0, idx_1]).T, axis=1)
 
    np.testing.assert_almost_equal(
@@ -421,8 +445,12 @@ def true_idx_logp(y):
    )
 
    def true_sub_idx_logp(y):
-        sub_idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1))
-        sub_idx_1 = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
+        sub_idx_0 = np.log(
+            0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1)
+        )
+        sub_idx_1 = np.log(
+            0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2)
+        )
        return log_softmax(np.stack([sub_idx_0, sub_idx_1]).T, axis=1)
 
    np.testing.assert_almost_equal(
@@ -443,14 +471,14 @@ def test_not_supported_marginalized():
    with MarginalModel() as m:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
-        y = pm.Normal("y", mu=pm.math.switch(idx, 0, 1))
+        pm.Normal("y", mu=pm.math.switch(idx, 0, 1))
        m.marginalize([idx])
 
    # Allowed, as index operation does not connect idx to y
    with MarginalModel() as m:
        p = pm.Beta("p", 1, 1)
        idx = pm.Bernoulli("idx", p=p, size=2)
-        y = pm.Normal("y", mu=pm.math.switch(idx, mu[0], mu[1]))
+        pm.Normal("y", mu=pm.math.switch(idx, mu[0], mu[1]))
        m.marginalize([idx])
 
    # Not allowed, as index operation connects idx to y
    with MarginalModel() as m:
@@ -458,7 +486,7 @@ def 
test_not_supported_marginalized(): p = pm.Beta("p", 1, 1) idx = pm.Bernoulli("idx", p=p, size=2) # Not allowed - y = pm.Normal("y", mu=mu[idx]) + pm.Normal("y", mu=mu[idx]) with pytest.raises(NotImplementedError): m.marginalize(idx) @@ -467,14 +495,14 @@ def test_not_supported_marginalized(): with MarginalModel() as m: p = pm.Beta("p", 1, 1) idx = pm.Bernoulli("idx", p=p, size=2) - y = pm.Normal("y", mu=mu[idx] + idx) + pm.Normal("y", mu=mu[idx] + idx) with pytest.raises(NotImplementedError): m.marginalize(idx) # Multivariate dependent RVs not supported with MarginalModel() as m: x = pm.Bernoulli("x", p=0.7) - y = pm.Dirichlet("y", a=pm.math.switch(x, [1, 1, 1], [10, 10, 10])) + pm.Dirichlet("y", a=pm.math.switch(x, [1, 1, 1], [10, 10, 10])) with pytest.raises( NotImplementedError, match="Marginalization with dependent Multivariate RVs not implemented", @@ -496,7 +524,9 @@ def test_marginalized_deterministic_and_potential(): with pytest.warns(UserWarning, match="There are multiple dependent variables"): m.marginalize([x]) - y_draw, z_draw, det_draw, pot_draw = pm.draw([y, z, det, pot], draws=5, random_seed=rng) + y_draw, z_draw, det_draw, pot_draw = pm.draw( + [y, z, det, pot], draws=5, random_seed=rng + ) np.testing.assert_almost_equal(y_draw + z_draw, det_draw) np.testing.assert_almost_equal(det_draw, pot_draw - 1) @@ -513,17 +543,18 @@ def test_not_supported_marginalized_deterministic_and_potential(): with MarginalModel() as m: x = pm.Bernoulli("x", p=0.7) y = pm.Normal("y", x) - det = pm.Deterministic("det", x + y) + pm.Deterministic("det", x + y) with pytest.raises( - NotImplementedError, match="Cannot marginalize x due to dependent Deterministic det" + NotImplementedError, + match="Cannot marginalize x due to dependent Deterministic det", ): m.marginalize([x]) with MarginalModel() as m: x = pm.Bernoulli("x", p=0.7) y = pm.Normal("y", x) - pot = pm.Potential("pot", x + y) + pm.Potential("pot", x + y) with pytest.raises( NotImplementedError, match="Cannot marginalize x due to dependent Potential pot" @@ -542,13 +573,15 @@ def test_not_supported_marginalized_deterministic_and_potential(): ( transforms.Interval(0, 1), pytest.warns( - UserWarning, match="which depends on the marginalized idx may no longer work" + UserWarning, + match="which depends on the marginalized idx may no longer work", ), ), ( transforms.Chain([transforms.log, transforms.Interval(0, 1)]), pytest.warns( - UserWarning, match="which depends on the marginalized idx may no longer work" + UserWarning, + match="which depends on the marginalized idx may no longer work", ), ), ), @@ -566,7 +599,7 @@ def test_marginalized_transforms(transform, expected_warning): initval=initval, default_transform=transform, ) - y = pm.Normal("y", 0, sigma, observed=data) + pm.Normal("y", 0, sigma, observed=data) with MarginalModel() as m: idx = pm.Categorical("idx", p=w) @@ -584,7 +617,7 @@ def test_marginalized_transforms(transform, expected_warning): initval=initval, default_transform=transform, ) - y = pm.Normal("y", 0, sigma, observed=data) + pm.Normal("y", 0, sigma, observed=data) with expected_warning: m.marginalize([idx]) @@ -615,7 +648,7 @@ def test_data_container(): with MarginalModel(coords={"obs": [0]}) as marginal_m: x = pm.Data("x", 2.5) idx = pm.Bernoulli("idx", p=0.7, dims="obs") - y = pm.Normal("y", idx * x, dims="obs") + pm.Normal("y", idx * x, dims="obs") marginal_m.marginalize([idx]) @@ -623,7 +656,7 @@ def test_data_container(): with pm.Model(coords={"obs": [0]}) as m_ref: x = pm.Data("x", 2.5) - y = pm.NormalMixture("y", 
w=[0.3, 0.7], mu=[0, x], dims="obs") + pm.NormalMixture("y", w=[0.3, 0.7], mu=[0, x], dims="obs") ref_logp_fn = m_ref.compile_logp() @@ -638,7 +671,6 @@ def test_data_container(): @pytest.mark.parametrize("univariate", (True, False)) def test_vector_univariate_mixture(univariate): - with MarginalModel() as m: idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ()) @@ -656,7 +688,9 @@ def dist(idx, size): if univariate: with pm.Model() as ref_m: - pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,)) + pm.NormalMixture( + "norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,) + ) else: with pm.Model() as ref_m: pm.Mixture( @@ -681,7 +715,9 @@ def dist(idx, size): @pytest.mark.parametrize("batch_chain", (False, True), ids=lambda x: f"batch_chain={x}") -@pytest.mark.parametrize("batch_emission", (False, True), ids=lambda x: f"batch_emission={x}") +@pytest.mark.parametrize( + "batch_emission", (False, True), ids=lambda x: f"batch_emission={x}" +) def test_marginalized_hmm_normal_emission(batch_chain, batch_emission): if batch_chain and not batch_emission: pytest.skip("Redundant implicit combination") @@ -690,21 +726,30 @@ def test_marginalized_hmm_normal_emission(batch_chain, batch_emission): P = [[0, 1], [1, 0]] init_dist = pm.Categorical.dist(p=[1, 0]) chain = DiscreteMarkovChain( - "chain", P=P, init_dist=init_dist, steps=3, shape=(3, 4) if batch_chain else None + "chain", + P=P, + init_dist=init_dist, + steps=3, + shape=(3, 4) if batch_chain else None, ) - emission = pm.Normal( - "emission", mu=chain * 2 - 1, sigma=1e-1, shape=(3, 4) if batch_emission else None + pm.Normal( + "emission", + mu=chain * 2 - 1, + sigma=1e-1, + shape=(3, 4) if batch_emission else None, ) m.marginalize([chain]) logp_fn = m.compile_logp() test_value = np.array([-1, 1, -1, 1]) - expected_logp = pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval() + expected_logp = ( + pm.logp(pm.Normal.dist(0, 1e-1), np.zeros_like(test_value)).sum().eval() + ) if batch_emission: test_value = np.broadcast_to(test_value, (3, 4)) expected_logp *= 3 - np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp) + np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) @pytest.mark.parametrize( @@ -718,17 +763,18 @@ def test_marginalized_hmm_categorical_emission(categorical_emission): init_dist = pm.Categorical.dist(p=[0.375, 0.625]) chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2) if categorical_emission: - emission = pm.Categorical( - "emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6]) + pm.Categorical( + "emission", + p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6]), ) else: - emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6)) + pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6)) m.marginalize([chain]) test_value = np.array([0, 0, 1]) expected_logp = np.log(0.1344) # Shown at the 10m22s mark in the video logp_fn = m.compile_logp() - np.testing.assert_allclose(logp_fn({f"emission": test_value}), expected_logp) + np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) @pytest.mark.parametrize("batch_emission1", (False, True)) @@ -740,8 +786,8 @@ def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2): P = [[0, 1], [1, 0]] init_dist = pm.Categorical.dist(p=[1, 0]) chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3) - emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, 
shape=emission1_shape) - emission_2 = pm.Normal( + pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape) + pm.Normal( "emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape ) @@ -755,7 +801,10 @@ def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2): expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier test_value_emission1 = np.broadcast_to(test_value, emission1_shape) test_value_emission2 = np.broadcast_to(-test_value, emission2_shape) - test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2} + test_point = { + "emission_1": test_value_emission1, + "emission_2": test_value_emission2, + } np.testing.assert_allclose(logp_fn(test_point), expected_logp) @@ -764,13 +813,15 @@ def test_mutable_indexing_jax_backend(): from pymc.sampling.jax import get_jaxified_logp with MarginalModel() as model: - data = pm.Data(f"data", np.zeros(10)) + data = pm.Data("data", np.zeros(10)) cat_effect = pm.Normal("cat_effect", sigma=1, shape=5) cat_effect_idx = pm.Data("cat_effect_idx", np.array([0, 1] * 5)) is_outlier = pm.Bernoulli("is_outlier", 0.4, shape=10) - pm.LogNormal("y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data) + pm.LogNormal( + "y", mu=cat_effect[cat_effect_idx], sigma=1 + is_outlier, observed=data + ) model.marginalize(["is_outlier"]) get_jaxified_logp(model) @@ -781,7 +832,7 @@ def create_model(model_class): idx = pm.Bernoulli("idx", p=0.5, dims="trial") mu = pt.where(idx, 1, -1) sigma = pm.HalfNormal("sigma") - y = pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10) + pm.Normal("y", mu=mu, sigma=sigma, dims="trial", observed=[1] * 10) return m marginal_m = marginalize(create_model(pm.Model), ["idx"]) @@ -793,7 +844,9 @@ def create_model(model_class): # Check forward graph representation is the same marginal_fgraph, _ = fgraph_from_model(marginal_m) reference_fgraph, _ = fgraph_from_model(reference_m) - assert equal_computations_up_to_root(marginal_fgraph.outputs, reference_fgraph.outputs) + assert equal_computations_up_to_root( + marginal_fgraph.outputs, reference_fgraph.outputs + ) # Check logp graph is the same # This fails because OpFromGraphs comparison is broken diff --git a/tests/model/transforms/test_autoreparam.py b/tests/model/transforms/test_autoreparam.py index 1d2173066..0f9f118a7 100644 --- a/tests/model/transforms/test_autoreparam.py +++ b/tests/model/transforms/test_autoreparam.py @@ -78,12 +78,18 @@ def test_multilevel(): s = pm.HalfNormal("s") a_g = pm.Normal("a_g", a, s, shape=(2,), dims="level") s_g = pm.HalfNormal("s_g") - a_ig = pm.Normal("a_ig", a_g, s_g, shape=(2, 2), dims=("county", "level")) + pm.Normal("a_ig", a_g, s_g, shape=(2, 2), dims=("county", "level")) model_r, vip = vip_reparametrize(model, ["a_g", "a_ig"]) assert "a_g" in vip.get_lambda() assert "a_ig" in vip.get_lambda() - assert {v.name for v in model_r.free_RVs} == {"a", "s", "a_g::tau_", "s_g", "a_ig::tau_"} + assert {v.name for v in model_r.free_RVs} == { + "a", + "s", + "a_g::tau_", + "s_g", + "a_ig::tau_", + } assert "a_g" in [v.name for v in model_r.deterministics] diff --git a/tests/statespace/test_SARIMAX.py b/tests/statespace/test_SARIMAX.py index fe9d8435e..7b5fd6a78 100644 --- a/tests/statespace/test_SARIMAX.py +++ b/tests/statespace/test_SARIMAX.py @@ -17,9 +17,6 @@ SARIMAX_STATE_STRUCTURES, SHORT_NAME_TO_LONG, ) -from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import - rng, -) from 
tests.statespace.utilities.test_helpers import ( load_nile_test_data, make_stationary_params, @@ -36,7 +33,16 @@ ["data", "D1.data", "data_star", "state_star_1", "state_star_2"], ["data", "D1.data", "D1^2.data", "data_star", "state_star_1", "state_star_2"], ["data", "state_1", "state_2", "state_3"], - ["data", "state_1", "state_2", "state_3", "state_4", "state_5", "state_6", "state_7"], + [ + "data", + "state_1", + "state_2", + "state_3", + "state_4", + "state_5", + "state_6", + "state_7", + ], [ "data", "state_1", @@ -156,7 +162,10 @@ (2, 1, 2, 2, 0, 2, 12), ] -ids = [f"p={p},d={d},q={q},P={P},D={D},Q={Q},S={S}" for (p, d, q, P, D, Q, S) in test_orders] +ids = [ + f"p={p},d={d},q={q},P={P},D={D},Q={Q},S={S}" + for (p, d, q, P, D, Q, S) in test_orders +] @pytest.fixture @@ -166,7 +175,9 @@ def data(): @pytest.fixture(scope="session") def arima_mod(): - return BayesianSARIMA(order=(2, 0, 1), stationary_initialization=True, verbose=False) + return BayesianSARIMA( + order=(2, 0, 1), stationary_initialization=True, verbose=False + ) @pytest.fixture(scope="session") @@ -177,10 +188,12 @@ def pymc_mod(arima_mod): # x0 = pm.Normal('x0', dims=['state']) # P0_diag = pm.Gamma('P0_diag', alpha=2, beta=1, dims=['state']) # P0 = pm.Deterministic('P0', pt.diag(P0_diag), dims=['state', 'state_aux']) - ar_params = pm.Normal("ar_params", sigma=0.1, dims=["ar_lag"]) - ma_params = pm.Normal("ma_params", sigma=1, dims=["ma_lag"]) - sigma_state = pm.Exponential("sigma_state", 0.5) - arima_mod.build_statespace_graph(data=data, save_kalman_filter_outputs_in_idata=True) + pm.Normal("ar_params", sigma=0.1, dims=["ar_lag"]) + pm.Normal("ma_params", sigma=1, dims=["ma_lag"]) + pm.Exponential("sigma_state", 0.5) + arima_mod.build_statespace_graph( + data=data, save_kalman_filter_outputs_in_idata=True + ) return pymc_mod @@ -201,17 +214,21 @@ def pymc_mod_interp(arima_mod_interp): data = load_nile_test_data() with pm.Model(coords=arima_mod_interp.coords) as pymc_mod: - x0 = pm.Normal("x0", dims=["state"]) + pm.Normal("x0", dims=["state"]) P0_sigma = pm.Exponential("P0_sigma", 1) - P0 = pm.Deterministic( - "P0", pt.eye(arima_mod_interp.k_states) * P0_sigma, dims=["state", "state_aux"] + pm.Deterministic( + "P0", + pt.eye(arima_mod_interp.k_states) * P0_sigma, + dims=["state", "state_aux"], ) - ar_params = pm.Normal("ar_params", sigma=0.1, dims=["ar_lag"]) - ma_params = pm.Normal("ma_params", sigma=1, dims=["ma_lag"]) - sigma_state = pm.Exponential("sigma_state", 0.5) - sigma_obs = pm.Exponential("sigma_obs", 0.1) + pm.Normal("ar_params", sigma=0.1, dims=["ar_lag"]) + pm.Normal("ma_params", sigma=1, dims=["ma_lag"]) + pm.Exponential("sigma_state", 0.5) + pm.Exponential("sigma_obs", 0.1) - arima_mod_interp.build_statespace_graph(data=data, save_kalman_filter_outputs_in_idata=True) + arima_mod_interp.build_statespace_graph( + data=data, save_kalman_filter_outputs_in_idata=True + ) return pymc_mod @@ -236,7 +253,9 @@ def test_harvey_state_names(p, d, q, P, D, Q, S, expected_names): @pytest.mark.parametrize("p,d,q,P,D,Q,S", test_orders) def test_make_SARIMA_transition_matrix(p, d, q, P, D, Q, S): T = make_SARIMA_transition_matrix(p, d, q, P, D, Q, S) - mod = sm.tsa.SARIMAX(np.random.normal(size=100), order=(p, d, q), seasonal_order=(P, D, Q, S)) + mod = sm.tsa.SARIMAX( + np.random.normal(size=100), order=(p, d, q), seasonal_order=(P, D, Q, S) + ) T2 = mod.ssm["transition"] if D > 2: @@ -256,29 +275,46 @@ def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng): sm_sarimax = sm.tsa.SARIMAX(data, order=(p, 
d, q), seasonal_order=(P, D, Q, S))
     param_names = sm_sarimax.param_names
 
-    param_d = {name: getattr(np, floatX)(rng.normal(scale=0.1) ** 2) for name in param_names}
+    param_d = {
+        name: getattr(np, floatX)(rng.normal(scale=0.1) ** 2) for name in param_names
+    }
 
-    res = sm_sarimax.fit_constrained(param_d)
+    sm_sarimax.fit_constrained(param_d)
 
     mod = BayesianSARIMA(
-        order=(p, d, q), seasonal_order=(P, D, Q, S), verbose=False, stationary_initialization=False
+        order=(p, d, q),
+        seasonal_order=(P, D, Q, S),
+        verbose=False,
+        stationary_initialization=False,
     )
 
-    with pm.Model() as pm_mod:
-        x0 = pm.Normal("x0", shape=(mod.k_states,))
-        P0 = pm.Deterministic("P0", pt.eye(mod.k_states, dtype=floatX))
+    with pm.Model():
+        pm.Normal("x0", shape=(mod.k_states,))
+        pm.Deterministic("P0", pt.eye(mod.k_states, dtype=floatX))
         if q > 0:
             pm.Deterministic(
                 "ma_params",
                 pt.as_tensor_variable(
-                    np.array([param_d[k] for k in param_d if k.startswith("ma.") and "S." not in k])
+                    np.array(
+                        [
+                            param_d[k]
+                            for k in param_d
+                            if k.startswith("ma.") and "S." not in k
+                        ]
+                    )
                 ),
             )
         if p > 0:
             pm.Deterministic(
                 "ar_params",
                 pt.as_tensor_variable(
-                    np.array([param_d[k] for k in param_d if k.startswith("ar.") and "S." not in k])
+                    np.array(
+                        [
+                            param_d[k]
+                            for k in param_d
+                            if k.startswith("ar.") and "S." not in k
+                        ]
+                    )
                 ),
             )
         if P > 0:
@@ -297,7 +333,9 @@
             ),
         )
 
-        pm.Deterministic("sigma_state", pt.as_tensor_variable(np.sqrt(param_d["sigma2"])))
+        pm.Deterministic(
+            "sigma_state", pt.as_tensor_variable(np.sqrt(param_d["sigma2"]))
+        )
 
     mod._insert_random_variables()
     matrices = pm.draw(mod.subbed_ssm)
@@ -306,7 +344,9 @@
     for matrix in ["transition", "selection", "state_cov", "obs_cov", "design"]:
         if matrix == "transition" and D > 2:
             pytest.skip("Statsmodels has a bug when D > 2, skip this test.")
-        assert_allclose(matrix_dict[matrix], sm_sarimax.ssm[matrix], err_msg=f"{matrix} not equal")
+        assert_allclose(
+            matrix_dict[matrix], sm_sarimax.ssm[matrix], err_msg=f"{matrix} not equal"
+        )
 
 
 @pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"])
@@ -319,7 +359,8 @@ def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
 
 def test_interpretable_raises_if_d_nonzero():
     with pytest.raises(
-        ValueError, match="Cannot use interpretable state structure with statespace differencing"
+        ValueError,
+        match="Cannot use interpretable state structure with statespace differencing",
     ):
         BayesianSARIMA(
             order=(2, 1, 1),
@@ -363,7 +404,9 @@ def test_interpretable_states_are_interpretable(arima_mod_interp, pymc_mod_inter
 )
 def test_representations_are_equivalent(p, d, q, P, D, Q, S, data, rng):
     if (d + D) > 0:
-        pytest.skip('state_structure = "interpretable" cannot include statespace differences')
+        pytest.skip(
+            'state_structure = "interpretable" cannot include statespace differences'
+        )
 
     shared_params = make_stationary_params(data, p, d, q, P, D, Q, S)
     test_values = {}
@@ -400,5 +443,7 @@
 @pytest.mark.parametrize("order, name", [((4, 1, 0, 0), "AR"), ((0, 0, 4, 1), "MA")])
 def test_invalid_order_raises(order, name):
     p, P, q, Q = order
-    with pytest.raises(ValueError, match=f"The following {name} and seasonal {name} terms overlap"):
+    with pytest.raises(
+        ValueError, match=f"The following {name} and seasonal {name} terms overlap"
+    ):
         BayesianSARIMA(order=(p, 0, q), 
seasonal_order=(P, 0, Q, 4)) diff --git a/tests/statespace/test_VARMAX.py b/tests/statespace/test_VARMAX.py index 43faebe8e..3c3d2506c 100644 --- a/tests/statespace/test_VARMAX.py +++ b/tests/statespace/test_VARMAX.py @@ -11,9 +11,6 @@ from pymc_experimental.statespace import BayesianVARMAX from pymc_experimental.statespace.utils.constants import SHORT_NAME_TO_LONG -from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import - rng, -) floatX = pytensor.config.floatX ps = [0, 1, 2, 3] @@ -55,15 +52,20 @@ def pymc_mod(varma_mod, data): state_chol, *_ = pm.LKJCholeskyCov( "state_chol", n=varma_mod.k_posdef, eta=1, sd_dist=pm.Exponential.dist(1) ) - ar_params = pm.Normal( - "ar_params", mu=0, sigma=0.1, dims=["observed_state", "ar_lag", "observed_state_aux"] + pm.Normal( + "ar_params", + mu=0, + sigma=0.1, + dims=["observed_state", "ar_lag", "observed_state_aux"], ) - state_cov = pm.Deterministic( + pm.Deterministic( "state_cov", state_chol @ state_chol.T, dims=["shock", "shock_aux"] ) - sigma_obs = pm.Exponential("sigma_obs", 1, dims=["observed_state"]) + pm.Exponential("sigma_obs", 1, dims=["observed_state"]) - varma_mod.build_statespace_graph(data=data, save_kalman_filter_outputs_in_idata=True) + varma_mod.build_statespace_graph( + data=data, save_kalman_filter_outputs_in_idata=True + ) return pymc_mod @@ -111,7 +113,7 @@ def test_VARMAX_update_matches_statsmodels(data, order, rng): for k in param_list } - res = sm_var.fit_constrained(param_d) + sm_var.fit_constrained(param_d) mod = BayesianVARMAX( k_endog=data.shape[1], @@ -124,20 +126,28 @@ def test_VARMAX_update_matches_statsmodels(data, order, rng): ar_shape = (mod.k_endog, mod.p, mod.k_endog) ma_shape = (mod.k_endog, mod.q, mod.k_endog) - with pm.Model() as pm_mod: - x0 = pm.Deterministic("x0", pt.zeros(mod.k_states, dtype=floatX)) - P0 = pm.Deterministic("P0", pt.eye(mod.k_states, dtype=floatX)) - ma_params = pm.Deterministic( + with pm.Model(): + pm.Deterministic("x0", pt.zeros(mod.k_states, dtype=floatX)) + pm.Deterministic("P0", pt.eye(mod.k_states, dtype=floatX)) + pm.Deterministic( "ma_params", - pt.as_tensor_variable(np.array([param_d[var] for var in ma])).reshape(ma_shape), + pt.as_tensor_variable(np.array([param_d[var] for var in ma])).reshape( + ma_shape + ), ) - ar_params = pm.Deterministic( + pm.Deterministic( "ar_params", - pt.as_tensor_variable(np.array([param_d[var] for var in ar])).reshape(ar_shape), + pt.as_tensor_variable(np.array([param_d[var] for var in ar])).reshape( + ar_shape + ), ) state_chol = np.zeros((mod.k_posdef, mod.k_posdef), dtype=floatX) - state_chol[np.tril_indices(mod.k_posdef)] = np.array([param_d[var] for var in state_cov]) - state_cov = pm.Deterministic("state_cov", pt.as_tensor_variable(state_chol @ state_chol.T)) + state_chol[np.tril_indices(mod.k_posdef)] = np.array( + [param_d[var] for var in state_cov] + ) + state_cov = pm.Deterministic( + "state_cov", pt.as_tensor_variable(state_chol @ state_chol.T) + ) mod._insert_random_variables() matrices = pm.draw(mod.subbed_ssm) @@ -161,7 +171,9 @@ def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng): {"n_steps": 10, "shock_size": np.array([1.0, 0.0, 0.0])}, { "n_steps": 10, - "shock_cov": np.array([[1.38, 0.58, -1.84], [0.58, 0.99, -0.82], [-1.84, -0.82, 2.51]]), + "shock_cov": np.array( + [[1.38, 0.58, -1.84], [0.58, 0.99, -0.82], [-1.84, -0.82, 2.51]] + ), }, { "shock_trajectory": np.r_[ @@ -172,12 +184,20 @@ def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng): }, ] -ids = 
["from-posterior-cov", "scalar_shock_size", "array_shock_size", "user-cov", "trajectory"] +ids = [ + "from-posterior-cov", + "scalar_shock_size", + "array_shock_size", + "user-cov", + "trajectory", +] @pytest.mark.parametrize("parameters", parameters, ids=ids) @pytest.mark.skipif(floatX == "float32", reason="Impulse covariance not PSD if float32") def test_impulse_response(parameters, varma_mod, idata, rng): - irf = varma_mod.impulse_response_function(idata.prior, random_seed=rng, **parameters) + irf = varma_mod.impulse_response_function( + idata.prior, random_seed=rng, **parameters + ) assert not np.any(np.isnan(irf.irf.values)) diff --git a/tests/statespace/test_coord_assignment.py b/tests/statespace/test_coord_assignment.py index 58310c4e2..0b5240f85 100644 --- a/tests/statespace/test_coord_assignment.py +++ b/tests/statespace/test_coord_assignment.py @@ -20,7 +20,13 @@ ) from tests.statespace.utilities.test_helpers import load_nile_test_data -function_names = ["pandas_date_freq", "pandas_date_nofreq", "pandas_nodate", "numpy", "pytensor"] +function_names = [ + "pandas_date_freq", + "pandas_date_nofreq", + "pandas_nodate", + "numpy", + "pytensor", +] expected_warning = [ does_not_raise(), pytest.warns(UserWarning, match=NO_FREQ_INFO_WARNING), @@ -78,10 +84,12 @@ def _create_model(f): 1, dims="state", ) - P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=("state", "state_aux")) - initial_trend = pm.Normal("initial_trend", dims="trend_state") - sigma_trend = pm.Exponential("sigma_trend", 1, dims="trend_shock") - ss_mod.build_statespace_graph(data, save_kalman_filter_outputs_in_idata=True) + pm.Deterministic("P0", pt.diag(P0_diag), dims=("state", "state_aux")) + pm.Normal("initial_trend", dims="trend_state") + pm.Exponential("sigma_trend", 1, dims="trend_shock") + ss_mod.build_statespace_graph( + data, save_kalman_filter_outputs_in_idata=True + ) return mod return _create_model @@ -92,7 +100,9 @@ def test_filter_output_coord_assignment(f, warning, create_model): with warning: pymc_model = create_model(f) - for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_state"]: + for output in ( + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_state"] + ): assert pymc_model.named_vars_to_dims[output] == FILTER_OUTPUT_DIMS[output] @@ -101,9 +111,9 @@ def test_model_build_without_coords(load_dataset): data = load_dataset("numpy") with pm.Model() as mod: P0_diag = pm.Exponential("P0_diag", 1, shape=(2,)) - P0 = pm.Deterministic("P0", pt.diag(P0_diag)) - initial_trend = pm.Normal("initial_trend", shape=(2,)) - sigma_trend = pm.Exponential("sigma_trend", 1, shape=(2,)) + pm.Deterministic("P0", pt.diag(P0_diag)) + pm.Normal("initial_trend", shape=(2,)) + pm.Exponential("sigma_trend", 1, shape=(2,)) ss_mod.build_statespace_graph(data, register_data=False) assert mod.coords == {} diff --git a/tests/statespace/test_distributions.py b/tests/statespace/test_distributions.py index 1d049ae92..1124b8e54 100644 --- a/tests/statespace/test_distributions.py +++ b/tests/statespace/test_distributions.py @@ -17,9 +17,6 @@ OBS_STATE_DIM, TIME_DIM, ) -from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import - rng, -) from tests.statespace.utilities.test_helpers import ( delete_rvs_from_model, fast_eval, @@ -52,9 +49,9 @@ def pymc_model(data): with pm.Model() as mod: data = pm.Data("data", data.values) P0_diag = pm.Exponential("P0_diag", 1, shape=(2,)) - P0 = pm.Deterministic("P0", pt.diag(P0_diag)) - initial_trend = pm.Normal("initial_trend", 
shape=(2,)) - sigma_trend = pm.Exponential("sigma_trend", 1, shape=(2,)) + pm.Deterministic("P0", pt.diag(P0_diag)) + pm.Normal("initial_trend", shape=(2,)) + pm.Exponential("sigma_trend", 1, shape=(2,)) return mod @@ -69,10 +66,10 @@ def pymc_model_2(data): with pm.Model(coords=coords) as mod: P0_diag = pm.Exponential("P0_diag", 1, shape=(2,)) - P0 = pm.Deterministic("P0", pt.diag(P0_diag)) - initial_trend = pm.Normal("initial_trend", shape=(2,)) - sigma_trend = pm.Exponential("sigma_trend", 1, shape=(2,)) - sigma_me = pm.Exponential("sigma_error", 1) + pm.Deterministic("P0", pt.diag(P0_diag)) + pm.Normal("initial_trend", shape=(2,)) + pm.Exponential("sigma_trend", 1, shape=(2,)) + pm.Exponential("sigma_error", 1) return mod @@ -105,7 +102,9 @@ def test_loglike_vectors_agree(kfilter, pymc_model): matrices = ss_mod.unpack_statespace() filter_outputs = ss_mod.kalman_filter.build_graph(pymc_model["data"], *matrices) - filter_mus, pred_mus, obs_mu, filter_covs, pred_covs, obs_cov, ll = filter_outputs + filter_mus, pred_mus, obs_mu, filter_covs, pred_covs, obs_cov, ll = ( + filter_outputs + ) test_ll = fast_eval(ll) @@ -151,7 +150,9 @@ def test_lgss_distribution_from_steps(output_name, ss_mod_me, pymc_model_2): matrices = ss_mod_me.unpack_statespace() # pylint: disable=unpacking-non-sequence - latent_states, obs_states = LinearGaussianStateSpace("states", *matrices, steps=100) + latent_states, obs_states = LinearGaussianStateSpace( + "states", *matrices, steps=100 + ) # pylint: enable=unpacking-non-sequence idata = pm.sample_prior_predictive(draws=10) @@ -174,7 +175,7 @@ def test_lgss_distribution_with_dims(output_name, ss_mod_me, pymc_model_2): steps=100, dims=[TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM], sequence_names=[], - k_endog=ss_mod_me.k_endog + k_endog=ss_mod_me.k_endog, ) # pylint: enable=unpacking-non-sequence idata = pm.sample_prior_predictive(draws=10) @@ -182,10 +183,16 @@ def test_lgss_distribution_with_dims(output_name, ss_mod_me, pymc_model_2): assert idata.prior.coords["time"].shape == (101,) assert all( - [dim in idata.prior.states_latent.coords.keys() for dim in [TIME_DIM, ALL_STATE_DIM]] + [ + dim in idata.prior.states_latent.coords.keys() + for dim in [TIME_DIM, ALL_STATE_DIM] + ] ) assert all( - [dim in idata.prior.states_observed.coords.keys() for dim in [TIME_DIM, OBS_STATE_DIM]] + [ + dim in idata.prior.states_observed.coords.keys() + for dim in [TIME_DIM, OBS_STATE_DIM] + ] ) assert not np.any(np.isnan(idata.prior[output_name].values)) @@ -205,12 +212,12 @@ def test_lgss_with_time_varying_inputs(output_name, rng): } with pm.Model(coords=coords): - exog_data = pm.Data("data_exog", X) + pm.Data("data_exog", X) P0_diag = pm.Exponential("P0_diag", 1, shape=(mod.k_states,)) - P0 = pm.Deterministic("P0", pt.diag(P0_diag)) - initial_trend = pm.Normal("initial_trend", shape=(2,)) - sigma_trend = pm.Exponential("sigma_trend", 1, shape=(2,)) - beta_exog = pm.Normal("beta_exog", shape=(3,)) + pm.Deterministic("P0", pt.diag(P0_diag)) + pm.Normal("initial_trend", shape=(2,)) + pm.Exponential("sigma_trend", 1, shape=(2,)) + pm.Normal("beta_exog", shape=(3,)) mod._insert_random_variables() mod._insert_data_variables() @@ -222,17 +229,23 @@ def test_lgss_with_time_varying_inputs(output_name, rng): *matrices, steps=9, sequence_names=["d", "Z"], - dims=[TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + dims=[TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM], ) # pylint: enable=unpacking-non-sequence idata = pm.sample_prior_predictive(draws=10) assert idata.prior.coords["time"].shape == (10,) assert all( - 
[dim in idata.prior.states_latent.coords.keys() for dim in [TIME_DIM, ALL_STATE_DIM]] + [ + dim in idata.prior.states_latent.coords.keys() + for dim in [TIME_DIM, ALL_STATE_DIM] + ] ) assert all( - [dim in idata.prior.states_observed.coords.keys() for dim in [TIME_DIM, OBS_STATE_DIM]] + [ + dim in idata.prior.states_observed.coords.keys() + for dim in [TIME_DIM, OBS_STATE_DIM] + ] ) assert not np.any(np.isnan(idata.prior[output_name].values)) diff --git a/tests/statespace/test_kalman_filter.py b/tests/statespace/test_kalman_filter.py index 15d1effa5..55432724c 100644 --- a/tests/statespace/test_kalman_filter.py +++ b/tests/statespace/test_kalman_filter.py @@ -137,7 +137,9 @@ def f_standard_nd(): ll_obs, ) = StandardFilter().build_graph(*inputs) - smoothed_states, smoothed_covs = ksmoother.build_graph(T, R, Q, filtered_states, filtered_covs) + smoothed_states, smoothed_covs = ksmoother.build_graph( + T, R, Q, filtered_states, filtered_covs + ) outputs = [ filtered_states, @@ -234,7 +236,9 @@ def test_missing_data(filter_func, filter_name, p, rng): @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) -@pytest.mark.parametrize("output_idx", [(0, 2), (3, 5)], ids=["smoothed_states", "smoothed_covs"]) +@pytest.mark.parametrize( + "output_idx", [(0, 2), (3, 5)], ids=["smoothed_states", "smoothed_covs"] +) def test_last_smoother_is_last_filtered(filter_func, output_idx, rng): p, m, r, n = 1, 5, 1, 10 inputs = make_test_inputs(p, m, r, n, rng) @@ -288,16 +292,24 @@ def test_filters_match_statsmodel_output(filter_func, n_missing, rng): @pytest.mark.parametrize( - "filter_func, filter_name", zip(filter_funcs[:-1], filter_names[:-1]), ids=filter_names[:-1] + "filter_func, filter_name", + zip(filter_funcs[:-1], filter_names[:-1]), + ids=filter_names[:-1], ) @pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"]) @pytest.mark.parametrize("obs_noise", [True, False]) -def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, obs_noise, rng): +def test_all_covariance_matrices_are_PSD( + filter_func, filter_name, n_missing, obs_noise, rng +): if (floatX == "float32") & (filter_name == "UnivariateFilter"): # TODO: These tests all pass locally for me with float32 but they fail on the CI, so i'm just disabling them. - pytest.skip("Univariate filter not stable at half precision without measurement error") + pytest.skip( + "Univariate filter not stable at half precision without measurement error" + ) - fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing) + fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper( + rng, n_missing + ) H *= int(obs_noise) inputs = [data, a0, P0, c, d, T, Z, R, H, Q] @@ -307,7 +319,9 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob cov_stack = outputs[output_idx] w, v = np.linalg.eig(cov_stack) - assert_array_less(0, w, err_msg=f"Smallest eigenvalue of {name}: {min(w.ravel())}") + assert_array_less( + 0, w, err_msg=f"Smallest eigenvalue of {name}: {min(w.ravel())}" + ) assert_allclose( cov_stack, np.swapaxes(cov_stack, -2, -1), @@ -348,4 +362,6 @@ def test_kalman_filter_jax(filter): pt_outputs = f_pt(*inputs_np) for name, jax_res, pt_res in zip(output_names, jax_outputs, pt_outputs): - assert_allclose(jax_res, pt_res, atol=ATOL, rtol=RTOL, err_msg=f"{name} failed!") + assert_allclose( + jax_res, pt_res, atol=ATOL, rtol=RTOL, err_msg=f"{name} failed!" 
+ ) diff --git a/tests/statespace/test_representation.py b/tests/statespace/test_representation.py index 10388d94d..ad088cad0 100644 --- a/tests/statespace/test_representation.py +++ b/tests/statespace/test_representation.py @@ -123,28 +123,34 @@ def test_assign_time_varying_matrices(self): assert_allclose(fast_eval(ssm["design"][0, 0]), 3.0, atol=atol) assert_allclose(fast_eval(ssm["transition"][0, :]), 2.7, atol=atol) assert_allclose(fast_eval(ssm["selection"][-1, -1]), 9.9, atol=atol) - assert_allclose(fast_eval(ssm["state_intercept"][:, 0]), np.arange(n), atol=atol) + assert_allclose( + fast_eval(ssm["state_intercept"][:, 0]), np.arange(n), atol=atol + ) def test_invalid_key_name_raises(self): ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=1) with self.assertRaises(IndexError) as e: - X = ssm["invalid_key"] + ssm["invalid_key"] msg = str(e.exception) self.assertEqual(msg, "invalid_key is an invalid state space matrix name") def test_non_string_key_raises(self): ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=1) with self.assertRaises(IndexError) as e: - X = ssm[0] + ssm[0] msg = str(e.exception) - self.assertEqual(msg, "First index must the name of a valid state space matrix.") + self.assertEqual( + msg, "First index must the name of a valid state space matrix." + ) def test_invalid_key_tuple_raises(self): ssm = PytensorRepresentation(k_endog=3, k_states=5, k_posdef=1) with self.assertRaises(IndexError) as e: - X = ssm[0, 1, 1] + ssm[0, 1, 1] msg = str(e.exception) - self.assertEqual(msg, "First index must the name of a valid state space matrix.") + self.assertEqual( + msg, "First index must the name of a valid state space matrix." + ) def test_slice_statespace_matrix(self): T = np.eye(5) diff --git a/tests/statespace/test_statespace.py b/tests/statespace/test_statespace.py index d93c66062..ec7e7ddd3 100644 --- a/tests/statespace/test_statespace.py +++ b/tests/statespace/test_statespace.py @@ -77,7 +77,11 @@ def make_symbolic_graph(self): P0 = np.eye(2, dtype=floatX) * 1e6 ss_mod = StateSpace( - k_endog=nile.shape[1], k_states=2, k_posdef=1, filter_type="standard", verbose=False + k_endog=nile.shape[1], + k_states=2, + k_posdef=1, + filter_type="standard", + verbose=False, ) for X, name in zip( [Z, R, H, Q, P0], @@ -92,9 +96,11 @@ def make_symbolic_graph(self): def pymc_mod(ss_mod): with pm.Model(coords=ss_mod.coords) as pymc_mod: rho = pm.Beta("rho", 1, 1) - zeta = pm.Deterministic("zeta", 1 - rho) + pm.Deterministic("zeta", 1 - rho) - ss_mod.build_statespace_graph(data=nile, save_kalman_filter_outputs_in_idata=True) + ss_mod.build_statespace_graph( + data=nile, save_kalman_filter_outputs_in_idata=True + ) names = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"] for name, matrix in zip(names, ss_mod.unpack_statespace()): pm.Deterministic(name, matrix) @@ -117,15 +123,15 @@ def exog_pymc_mod(exog_ss_mod, rng): X = rng.normal(size=(100, 3)).astype(floatX) with pm.Model(coords=exog_ss_mod.coords) as m: - exog_data = pm.Data("data_exog", X) - initial_trend = pm.Normal("initial_trend", dims=["trend_state"]) + pm.Data("data_exog", X) + pm.Normal("initial_trend", dims=["trend_state"]) P0_sigma = pm.Exponential("P0_sigma", 1) - P0 = pm.Deterministic( + pm.Deterministic( "P0", pt.eye(exog_ss_mod.k_states) * P0_sigma, dims=["state", "state_aux"] ) - beta_exog = pm.Normal("beta_exog", dims=["exog_state"]) + pm.Normal("beta_exog", dims=["exog_state"]) - sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"]) + pm.Exponential("sigma_trend", 1, 
dims=["trend_shock"]) exog_ss_mod.build_statespace_graph(y, save_kalman_filter_outputs_in_idata=True) return m @@ -151,15 +157,19 @@ def idata_exog(exog_pymc_mod, rng): def test_invalid_filter_name_raises(): - msg = "The following are valid filter types: " + ", ".join(list(FILTER_FACTORY.keys())) + msg = "The following are valid filter types: " + ", ".join( + list(FILTER_FACTORY.keys()) + ) with pytest.raises(NotImplementedError, match=msg): - mod = make_statespace_mod(k_endog=1, k_states=5, k_posdef=1, filter_type="invalid_filter") + make_statespace_mod( + k_endog=1, k_states=5, k_posdef=1, filter_type="invalid_filter" + ) def test_singleseriesfilter_raises_if_k_endog_gt_one(): msg = 'Cannot use filter_type = "single" with multiple observed time series' with pytest.raises(ValueError, match=msg): - mod = make_statespace_mod(k_endog=10, k_states=5, k_posdef=1, filter_type="single") + make_statespace_mod(k_endog=10, k_states=5, k_posdef=1, filter_type="single") def test_unpack_before_insert_raises(rng): @@ -171,7 +181,7 @@ def test_unpack_before_insert_raises(rng): msg = "Cannot unpack the complete statespace system until PyMC model variables have been inserted." with pytest.raises(ValueError, match=msg): - outputs = mod.unpack_statespace() + mod.unpack_statespace() def test_unpack_matrices(rng): @@ -194,19 +204,19 @@ def test_param_names_raises_on_base_class(): k_endog=1, k_states=5, k_posdef=1, filter_type="standard", verbose=False ) with pytest.raises(NotImplementedError): - x = mod.param_names + mod.param_names def test_base_class_raises(): with pytest.raises(NotImplementedError): - mod = PyMCStateSpace( + PyMCStateSpace( k_endog=1, k_states=5, k_posdef=1, filter_type="standard", verbose=False ) def test_update_raises_if_missing_variables(ss_mod): - with pm.Model() as mod: - rho = pm.Normal("rho") + with pm.Model(): + pm.Normal("rho") msg = "The following required model parameters were not found in the PyMC model: zeta" with pytest.raises(ValueError, match=msg): ss_mod._insert_random_variables() @@ -216,9 +226,9 @@ def test_build_statespace_graph_warns_if_data_has_nans(): # Breaks tests if it uses the session fixtures because we can't call build_statespace_graph over and over ss_mod = st.LevelTrendComponent(order=1, innovations_order=0).build(verbose=False) - with pm.Model() as pymc_mod: - initial_trend = pm.Normal("initial_trend", shape=(1,)) - P0 = pm.Deterministic("P0", pt.eye(1, dtype=floatX)) + with pm.Model(): + pm.Normal("initial_trend", shape=(1,)) + pm.Deterministic("P0", pt.eye(1, dtype=floatX)) with pytest.warns(pm.ImputationWarning): ss_mod.build_statespace_graph( data=np.full((10, 1), np.nan, dtype=floatX), register_data=False @@ -229,13 +239,15 @@ def test_build_statespace_graph_raises_if_data_has_missing_fill(): # Breaks tests if it uses the session fixtures because we can't call build_statespace_graph over and over ss_mod = st.LevelTrendComponent(order=1, innovations_order=0).build(verbose=False) - with pm.Model() as pymc_mod: - initial_trend = pm.Normal("initial_trend", shape=(1,)) - P0 = pm.Deterministic("P0", pt.eye(1, dtype=floatX)) + with pm.Model(): + pm.Normal("initial_trend", shape=(1,)) + pm.Deterministic("P0", pt.eye(1, dtype=floatX)) with pytest.raises(ValueError, match="Provided data contains the value 1.0"): data = np.ones((10, 1), dtype=floatX) data[3] = np.nan - ss_mod.build_statespace_graph(data=data, missing_fill_value=1.0, register_data=False) + ss_mod.build_statespace_graph( + data=data, missing_fill_value=1.0, register_data=False + ) def 
test_build_statespace_graph(pymc_mod):
@@ -281,12 +293,21 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
 def test_forecast(filter_output, ss_mod, idata, rng):
     time_idx = idata.posterior.coords["time"].values
     forecast_idata = ss_mod.forecast(
-        idata, start=time_idx[-1], periods=10, filter_output=filter_output, random_seed=rng
+        idata,
+        start=time_idx[-1],
+        periods=10,
+        filter_output=filter_output,
+        random_seed=rng,
     )
 
     assert forecast_idata.coords["time"].values.shape == (10,)
     assert forecast_idata.forecast_latent.dims == ("chain", "draw", "time", "state")
-    assert forecast_idata.forecast_observed.dims == ("chain", "draw", "time", "observed_state")
+    assert forecast_idata.forecast_observed.dims == (
+        "chain",
+        "draw",
+        "time",
+        "observed_state",
+    )
 
     assert not np.any(np.isnan(forecast_idata.forecast_latent.values))
     assert not np.any(np.isnan(forecast_idata.forecast_observed.values))
@@ -295,7 +316,9 @@
 @pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
 def test_forecast_fails_if_exog_needed(exog_ss_mod, idata_exog):
     time_idx = idata_exog.observed_data.coords["time"].values
-    with pytest.xfail("Scenario-based forcasting with exogenous variables not currently supported"):
-        forecast_idata = exog_ss_mod.forecast(
+    with pytest.xfail(
+        "Scenario-based forecasting with exogenous variables not currently supported"
+    ):
+        exog_ss_mod.forecast(
             idata_exog, start=time_idx[-1], periods=10, random_seed=rng
         )
diff --git a/tests/statespace/test_statespace_JAX.py b/tests/statespace/test_statespace_JAX.py
index d9a0c4f96..cc4e1b00f 100644
--- a/tests/statespace/test_statespace_JAX.py
+++ b/tests/statespace/test_statespace_JAX.py
@@ -12,13 +12,6 @@
     MATRIX_NAMES,
     SMOOTHER_OUTPUT_NAMES,
 )
-from tests.statespace.test_statespace import (  # pylint: disable=unused-import
-    exog_ss_mod,
-    ss_mod,
-)
-from tests.statespace.utilities.shared_fixtures import (  # pylint: disable=unused-import
-    rng,
-)
 from tests.statespace.utilities.test_helpers import load_nile_test_data
 
 pytest.importorskip("jax")
@@ -34,7 +27,7 @@ def pymc_mod(ss_mod):
     with pm.Model(coords=ss_mod.coords) as pymc_mod:
         rho = pm.Beta("rho", 1, 1)
-        zeta = pm.Deterministic("zeta", 1 - rho)
+        pm.Deterministic("zeta", 1 - rho)
 
         ss_mod.build_statespace_graph(
             data=nile, mode="JAX", save_kalman_filter_outputs_in_idata=True
@@ -52,15 +45,15 @@ def exog_pymc_mod(exog_ss_mod, rng):
     X = rng.normal(size=(100, 3)).astype(floatX)
 
     with pm.Model(coords=exog_ss_mod.coords) as m:
-        exog_data = pm.Data("data_exog", X)
-        initial_trend = pm.Normal("initial_trend", dims=["trend_state"])
+        pm.Data("data_exog", X)
+        pm.Normal("initial_trend", dims=["trend_state"])
         P0_sigma = pm.Exponential("P0_sigma", 1)
-        P0 = pm.Deterministic(
+        pm.Deterministic(
             "P0", pt.eye(exog_ss_mod.k_states) * P0_sigma, dims=["state", "state_aux"]
         )
-        beta_exog = pm.Normal("beta_exog", dims=["exog_state"])
+        pm.Normal("beta_exog", dims=["exog_state"])
 
-        sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
+        pm.Exponential("sigma_trend", 1, dims=["trend_shock"])
         exog_ss_mod.build_statespace_graph(y, mode="JAX")
 
     return m
@@ -144,12 +137,21 @@ def test_forecast(filter_output, ss_mod, idata, rng):
 
     with pytest.warns(UserWarning, match="The RandomType SharedVariables"):
         forecast_idata = ss_mod.forecast(
-            idata, start=time_idx[-1], periods=10, filter_output=filter_output, random_seed=rng
+            idata,
+            start=time_idx[-1],
+            periods=10,
+            filter_output=filter_output,
+            random_seed=rng,
         )
assert forecast_idata.coords["time"].values.shape == (10,) assert forecast_idata.forecast_latent.dims == ("chain", "draw", "time", "state") - assert forecast_idata.forecast_observed.dims == ("chain", "draw", "time", "observed_state") + assert forecast_idata.forecast_observed.dims == ( + "chain", + "draw", + "time", + "observed_state", + ) assert not np.any(np.isnan(forecast_idata.forecast_latent.values)) assert not np.any(np.isnan(forecast_idata.forecast_observed.values)) diff --git a/tests/statespace/test_structural.py b/tests/statespace/test_structural.py index 30a037811..413ddab00 100644 --- a/tests/statespace/test_structural.py +++ b/tests/statespace/test_structural.py @@ -24,9 +24,6 @@ SHOCK_DIM, SHORT_NAME_TO_LONG, ) -from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import - rng, -) from tests.statespace.utilities.test_helpers import ( assert_pattern_repeats, simulate_from_numpy_model, @@ -82,7 +79,9 @@ def _assert_coord_shapes_match_matrices(mod, params): assert c.shape[-1:] == ( n_states, ), f"c expected to have shape (n_states, ), found {c.shape[-1:]}" - assert d.shape[-1:] == (n_obs,), f"d expected to have shape (n_obs, ), found {d.shape[-1:]}" + assert d.shape[-1:] == ( + n_obs, + ), f"d expected to have shape (n_obs, ), found {d.shape[-1:]}" assert T.shape[-2:] == ( n_states, n_states, @@ -118,7 +117,9 @@ def _assert_keys_match(test_dict, expected_dict): expected_keys = list(expected_dict.keys()) param_keys = list(test_dict.keys()) key_diff = set(expected_keys) - set(param_keys) - assert len(key_diff) == 0, f'{", ".join(key_diff)} were not found in the test_dict keys.' + assert ( + len(key_diff) == 0 + ), f'{", ".join(key_diff)} were not found in the test_dict keys.' key_diff = set(param_keys) - set(expected_keys) assert ( @@ -300,10 +301,12 @@ def create_structural_model_and_equivalent_statsmodel( sm_params["sigma2.level"] = sigma if stochastic_trend: sigma = sigma_level_value.pop(0) - sm_params[f"sigma2.trend"] = sigma + sm_params["sigma2.trend"] = sigma comp = st.LevelTrendComponent( - name="level", order=level_trend_order, innovations_order=level_trend_innov_order + name="level", + order=level_trend_order, + innovations_order=level_trend_innov_order, ) components.append(comp) @@ -318,7 +321,8 @@ def create_structural_model_and_equivalent_statsmodel( expected_coords[ALL_STATE_AUX_DIM] += state_names seasonal_dict = { - "seasonal" if i == 0 else f"seasonal.L{i}": c for i, c in enumerate(seasonal_coefs) + "seasonal" if i == 0 else f"seasonal.L{i}": c + for i, c in enumerate(seasonal_coefs) } sm_init.update(seasonal_dict) @@ -345,7 +349,9 @@ def create_structural_model_and_equivalent_statsmodel( s = d["period"] last_state_not_identified = (s / n) == 2.0 n_states = 2 * n - int(last_state_not_identified) - state_names = [f"seasonal_{s}_{f}_{i}" for i in range(n) for f in ["Cos", "Sin"]] + state_names = [ + f"seasonal_{s}_{f}_{i}" for i in range(n) for f in ["Cos", "Sin"] + ] seasonal_params = rng.normal(size=n_states).astype(floatX) @@ -354,7 +360,9 @@ def create_structural_model_and_equivalent_statsmodel( expected_coords[ALL_STATE_DIM] += state_names expected_coords[ALL_STATE_AUX_DIM] += state_names expected_coords[f"seasonal_{s}_state"] += ( - tuple(state_names[:-1]) if last_state_not_identified else tuple(state_names) + tuple(state_names[:-1]) + if last_state_not_identified + else tuple(state_names) ) for param in seasonal_params: @@ -472,7 +480,9 @@ def create_structural_model_and_equivalent_statsmodel( ], ) 
@pytest.mark.parametrize("autoregressive", [None, 3]) -@pytest.mark.parametrize("seasonal, stochastic_seasonal", [(None, False), (12, False), (12, True)]) +@pytest.mark.parametrize( + "seasonal, stochastic_seasonal", [(None, False), (12, False), (12, True)] +) @pytest.mark.parametrize( "freq_seasonal, stochastic_freq_seasonal", [ @@ -485,8 +495,12 @@ def create_structural_model_and_equivalent_statsmodel( "cycle, damped_cycle, stochastic_cycle", [(False, False, False), (True, False, True), (True, True, True)], ) -@pytest.mark.filterwarnings("ignore::statsmodels.tools.sm_exceptions.ConvergenceWarning") -@pytest.mark.filterwarnings("ignore::statsmodels.tools.sm_exceptions.SpecificationWarning") +@pytest.mark.filterwarnings( + "ignore::statsmodels.tools.sm_exceptions.ConvergenceWarning" +) +@pytest.mark.filterwarnings( + "ignore::statsmodels.tools.sm_exceptions.SpecificationWarning" +) def test_structural_model_against_statsmodels( level, trend, @@ -526,7 +540,11 @@ def test_structural_model_against_statsmodels( if len(sm_init) > 0: init_array = np.concatenate( - [np.atleast_1d(sm_init[k]).ravel() for k in sm_mod.state_names if k != "dummy"] + [ + np.atleast_1d(sm_init[k]).ravel() + for k in sm_mod.state_names + if k != "dummy" + ] ) sm_mod.initialize_known(init_array, np.eye(sm_mod.k_states)) else: @@ -545,7 +563,9 @@ def test_structural_model_against_statsmodels( _assert_coord_shapes_match_matrices(built_model, params) _assert_param_dims_correct(built_model.param_dims, expected_dims) _assert_coords_correct(built_model.coords, expected_coords) - _assert_params_info_correct(built_model.param_info, built_model.coords, built_model.param_dims) + _assert_params_info_correct( + built_model.param_info, built_model.coords, built_model.param_dims + ) def test_level_trend_model(rng): @@ -664,9 +684,16 @@ def test_cycle_component_deterministic(rng): def test_cycle_component_with_dampening(rng): cycle = st.CycleComponent( - name="cycle", cycle_length=12, estimate_cycle_length=False, innovations=False, dampen=True + name="cycle", + cycle_length=12, + estimate_cycle_length=False, + innovations=False, + dampen=True, ) - params = {"cycle": np.array([10.0, 10.0], dtype=floatX), "cycle_dampening_factor": 0.75} + params = { + "cycle": np.array([10.0, 10.0], dtype=floatX), + "cycle_dampening_factor": 0.75, + } x, y = simulate_from_numpy_model(cycle, rng, params, steps=100) # Check that the cycle dampens to zero over time @@ -734,15 +761,21 @@ def test_add_components(): all_params = ll_params.copy() all_params.update(se_params) - (ll_x0, ll_P0, ll_c, ll_d, ll_T, ll_Z, ll_R, ll_H, ll_Q) = unpack_symbolic_matrices_with_params( - ll, ll_params + (ll_x0, ll_P0, ll_c, ll_d, ll_T, ll_Z, ll_R, ll_H, ll_Q) = ( + unpack_symbolic_matrices_with_params(ll, ll_params) ) - (se_x0, se_P0, se_c, se_d, se_T, se_Z, se_R, se_H, se_Q) = unpack_symbolic_matrices_with_params( - se, se_params + (se_x0, se_P0, se_c, se_d, se_T, se_Z, se_R, se_H, se_Q) = ( + unpack_symbolic_matrices_with_params(se, se_params) ) x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, all_params) - for property in ["param_names", "shock_names", "param_info", "coords", "param_dims"]: + for property in [ + "param_names", + "shock_names", + "param_info", + "coords", + "param_dims", + ]: assert [x in getattr(mod, property) for x in getattr(ll, property)] assert [x in getattr(mod, property) for x in getattr(se, property)] @@ -751,7 +784,9 @@ def test_add_components(): all_mats = [T, R, Q] for ll_mat, se_mat, all_mat in zip(ll_mats, se_mats, 
all_mats): - assert_allclose(all_mat, linalg.block_diag(ll_mat, se_mat), atol=ATOL, rtol=RTOL) + assert_allclose( + all_mat, linalg.block_diag(ll_mat, se_mat), atol=ATOL, rtol=RTOL + ) ll_mats = [ll_x0, ll_c, ll_Z] se_mats = [se_x0, se_c, se_Z] @@ -759,7 +794,9 @@ def test_add_components(): axes = [0, 0, 1] for ll_mat, se_mat, all_mat, axis in zip(ll_mats, se_mats, all_mats, axes): - assert_allclose(all_mat, np.concatenate([ll_mat, se_mat], axis=axis), atol=ATOL, rtol=RTOL) + assert_allclose( + all_mat, np.concatenate([ll_mat, se_mat], axis=axis), atol=ATOL, rtol=RTOL + ) def test_filter_scans_time_varying_design_matrix(rng): @@ -771,12 +808,12 @@ def test_filter_scans_time_varying_design_matrix(rng): reg = st.RegressionComponent(state_names=["a", "b"], name="exog") mod = reg.build(verbose=False) - with pm.Model(coords=mod.coords) as m: - data_exog = pm.Data("data_exog", data.values) + with pm.Model(coords=mod.coords): + pm.Data("data_exog", data.values) x0 = pm.Normal("x0", dims=["state"]) P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"]) - beta_exog = pm.Normal("beta_exog", dims=["exog_state"]) + pm.Normal("beta_exog", dims=["exog_state"]) mod.build_statespace_graph(y) x0, P0, c, d, T, Z, R, H, Q = mod.unpack_statespace() @@ -789,7 +826,9 @@ def test_filter_scans_time_varying_design_matrix(rng): assert_allclose(prior_Z[0, :, :, 0, :], data.values[None].repeat(10, axis=0)) -@pytest.mark.skipif(floatX.endswith("32"), reason="Prior covariance not PSD at half-precision") +@pytest.mark.skipif( + floatX.endswith("32"), reason="Prior covariance not PSD at half-precision" +) def test_extract_components_from_idata(rng): time_idx = pd.date_range(start="2000-01-01", freq="D", periods=100) data = pd.DataFrame(rng.normal(size=(100, 2)), columns=["a", "b"], index=time_idx) @@ -797,21 +836,23 @@ def test_extract_components_from_idata(rng): y = pd.DataFrame(rng.normal(size=(100, 1)), columns=["data"], index=time_idx) ll = st.LevelTrendComponent() - season = st.FrequencySeasonality(name="seasonal", season_length=12, n=2, innovations=False) + season = st.FrequencySeasonality( + name="seasonal", season_length=12, n=2, innovations=False + ) reg = st.RegressionComponent(state_names=["a", "b"], name="exog") me = st.MeasurementError("obs") mod = (ll + season + reg + me).build(verbose=False) - with pm.Model(coords=mod.coords) as m: - data_exog = pm.Data("data_exog", data.values) + with pm.Model(coords=mod.coords): + pm.Data("data_exog", data.values) x0 = pm.Normal("x0", dims=["state"]) P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"]) - beta_exog = pm.Normal("beta_exog", dims=["exog_state"]) - initial_trend = pm.Normal("initial_trend", dims=["trend_state"]) - sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"]) - seasonal_coefs = pm.Normal("seasonal", dims=["seasonal_state"]) - sigma_obs = pm.Exponential("sigma_obs", 1) + pm.Normal("beta_exog", dims=["exog_state"]) + pm.Normal("initial_trend", dims=["trend_state"]) + pm.Exponential("sigma_trend", 1, dims=["trend_shock"]) + pm.Normal("seasonal", dims=["seasonal_state"]) + pm.Exponential("sigma_obs", 1) mod.build_statespace_graph(y) @@ -821,7 +862,13 @@ def test_extract_components_from_idata(rng): filter_prior = mod.sample_conditional_prior(prior) comp_prior = mod.extract_components_from_idata(filter_prior) comp_states = comp_prior.filtered_prior.coords["state"].values - expected_states = ["LevelTrend[level]", "LevelTrend[trend]", "seasonal", "exog[a]", "exog[b]"] + expected_states = [ + 
"LevelTrend[level]", + "LevelTrend[trend]", + "seasonal", + "exog[a]", + "exog[b]", + ] missing = set(comp_states) - set(expected_states) assert len(missing) == 0, missing diff --git a/tests/statespace/utilities/test_helpers.py b/tests/statespace/utilities/test_helpers.py index 7f2183c14..15076db26 100644 --- a/tests/statespace/utilities/test_helpers.py +++ b/tests/statespace/utilities/test_helpers.py @@ -58,7 +58,9 @@ def initialize_filter(kfilter, mode=None): ll_obs, ) = kfilter.build_graph(*inputs, mode=mode) - smoothed_states, smoothed_covs = ksmoother.build_graph(T, R, Q, filtered_states, filtered_covs) + smoothed_states, smoothed_covs = ksmoother.build_graph( + T, R, Q, filtered_states, filtered_covs + ) outputs = [ filtered_states, @@ -210,7 +212,9 @@ def unpack_statespace(ssm): return [ssm[SHORT_NAME_TO_LONG[x]] for x in MATRIX_NAMES] -def unpack_symbolic_matrices_with_params(mod, param_dict, data_dict=None, mode="FAST_COMPILE"): +def unpack_symbolic_matrices_with_params( + mod, param_dict, data_dict=None, mode="FAST_COMPILE" +): inputs = list(mod._name_to_variable.values()) if data_dict is not None: inputs += list(mod._name_to_data.values()) @@ -233,7 +237,9 @@ def simulate_from_numpy_model(mod, rng, param_dict, data_dict=None, steps=100): """ Helper function to visualize the components outside of a PyMC model context """ - x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, param_dict, data_dict) + x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params( + mod, param_dict, data_dict + ) k_states = mod.k_states k_posdef = mod.k_posdef @@ -287,7 +293,9 @@ def make_stationary_params(data, p, d, q, P, D, Q, S): sm_sarimax = sm.tsa.SARIMAX(data, order=(p, d, q), seasonal_order=(P, D, Q, S)) res = sm_sarimax.fit(disp=False) - param_dict = dict(ar_params=[], ma_params=[], seasonal_ar_params=[], seasonal_ma_params=[]) + param_dict = dict( + ar_params=[], ma_params=[], seasonal_ar_params=[], seasonal_ma_params=[] + ) for name, param in zip(res.param_names, res.params): if name.startswith("ar.S"): diff --git a/tests/test_blackjax_smc.py b/tests/test_blackjax_smc.py index b669558fe..45f50d800 100644 --- a/tests/test_blackjax_smc.py +++ b/tests/test_blackjax_smc.py @@ -60,7 +60,7 @@ def two_gaussians(x): with pm.Model() as m: X = pm.Uniform("X", lower=-2, upper=2.0, shape=n) - llk = pm.Potential("muh", two_gaussians(X)) + pm.Potential("muh", two_gaussians(X)) return m, mu1 @@ -68,7 +68,7 @@ def two_gaussians(x): def fast_model(): with pm.Model() as m: x = pm.Normal("x", 0, 1) - y = pm.Normal("y", x, 1, observed=0) + pm.Normal("y", x, 1, observed=0) return m @@ -115,7 +115,9 @@ def test_sample_smc_blackjax(kernel, check_for_integration_steps, inner_kernel_p assert inference_data.posterior.attrs[attribute] == value for diagnostic in ["lambda_evolution", "log_likelihood_increments"]: - assert inference_data.posterior.attrs[diagnostic].shape == (iterations_to_diagnose,) + assert inference_data.posterior.attrs[diagnostic].shape == ( + iterations_to_diagnose, + ) for diagnostic in ["ancestors_evolution", "weights_evolution"]: assert inference_data.posterior.attrs[diagnostic].shape == ( @@ -134,33 +136,41 @@ def test_blackjax_particles_from_pymc_population_univariate(): model = fast_model() population = {"x": np.array([2, 3, 4])} blackjax_particles = blackjax_particles_from_pymc_population(model, population) - jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])]) + jax.tree.map( + np.testing.assert_allclose, blackjax_particles, 
[np.array([[2], [3], [4]])] + ) def test_blackjax_particles_from_pymc_population_multivariate(): with pm.Model() as model: x = pm.Normal("x", 0, 1) z = pm.Normal("z", 0, 1) - y = pm.Normal("y", x + z, 1, observed=0) + pm.Normal("y", x + z, 1, observed=0) - population = {"x": np.array([0.34614613, 1.09163261, -0.44526825]), "z": np.array([1, 2, 3])} + population = { + "x": np.array([0.34614613, 1.09163261, -0.44526825]), + "z": np.array([1, 2, 3]), + } blackjax_particles = blackjax_particles_from_pymc_population(model, population) jax.tree.map( np.testing.assert_allclose, blackjax_particles, - [np.array([[0.34614613], [1.09163261], [-0.44526825]]), np.array([[1], [2], [3]])], + [ + np.array([[0.34614613], [1.09163261], [-0.44526825]]), + np.array([[1], [2], [3]]), + ], ) def simple_multivariable_model(): """ A simple model that has a multivariate variable, - a has more than one variable (multivariable) + and has more than one variable (multivariable) """ with pm.Model() as model: - x = pm.Normal("x", 0, 1, shape=2) + pm.Normal("x", 0, 1, shape=2) z = pm.Normal("z", 0, 1) - y = pm.Normal("y", z, 1, observed=0) + pm.Normal("y", z, 1, observed=0) return model @@ -182,7 +192,9 @@ def test_arviz_from_particles(): with model: inference_data = arviz_from_particles(model, particles) - assert inference_data.posterior.sizes == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2}) + assert inference_data.posterior.sizes == Frozen( + {"chain": 1, "draw": 3, "x_dim_0": 2} + ) assert inference_data.posterior.data_vars.dtypes == Frozen( {"x": dtype("float64"), "z": dtype("float64")} ) diff --git a/tests/test_histogram_approximation.py b/tests/test_histogram_approximation.py index 59521d534..b8da3be5f 100644 --- a/tests/test_histogram_approximation.py +++ b/tests/test_histogram_approximation.py @@ -60,7 +60,9 @@ def test_histogram_init_discrete(use_dask, min_count, ndims): dask = pytest.importorskip("dask") dask_df = pytest.importorskip("dask.dataframe") data = dask_df.from_array(data) - histogram = pmx.distributions.histogram_utils.discrete_histogram(data, min_count=min_count) + histogram = pmx.distributions.histogram_utils.discrete_histogram( + data, min_count=min_count + ) if use_dask: (histogram,) = dask.compute(histogram) assert isinstance(histogram, dict) @@ -81,16 +83,16 @@ def test_histogram_init_discrete(use_dask, min_count, ndims): def test_histogram_approx_cont(use_dask, ndims): data = np.random.randn(*(10000, *(2,) * (ndims - 1))) if use_dask: - dask = pytest.importorskip("dask") + pytest.importorskip("dask") dask_df = pytest.importorskip("dask.dataframe") data = dask_df.from_array(data) with pm.Model(): m = pm.Normal("m") s = pm.HalfNormal("s", size=2 if ndims > 1 else 1) - pot = pmx.distributions.histogram_utils.histogram_approximation( + pmx.distributions.histogram_utils.histogram_approximation( "histogram_potential", pm.Normal.dist(m, s), observed=data, n_quantiles=1000 ) - trace = pm.sample(10, tune=0) # very fast + pm.sample(10, tune=0) # very fast @pytest.mark.parametrize("use_dask", [True, False]) @@ -98,12 +100,12 @@ def test_histogram_approx_cont(use_dask, ndims): def test_histogram_approx_discrete(use_dask, ndims): data = np.random.randint(0, 100, size=(10000, *(2,) * (ndims - 1))) if use_dask: - dask = pytest.importorskip("dask") + pytest.importorskip("dask") dask_df = pytest.importorskip("dask.dataframe") data = dask_df.from_array(data) with pm.Model(): s = pm.Exponential("s", 1.0, size=2 if ndims > 1 else 1) - pot = pmx.distributions.histogram_utils.histogram_approximation( + 
pmx.distributions.histogram_utils.histogram_approximation(
             "histogram_potential", pm.Poisson.dist(s), observed=data, min_count=10
         )
-        trace = pm.sample(10, tune=0)  # very fast
+        pm.sample(10, tune=0)  # very fast
diff --git a/tests/test_laplace.py b/tests/test_laplace.py
index 49e5614b2..836d97185 100644
--- a/tests/test_laplace.py
+++ b/tests/test_laplace.py
@@ -25,7 +25,6 @@
     + "To suppress this warning set `negate_output=False`:FutureWarning",
 )
 def test_laplace():
-
     # Example originates from Bayesian Data Analysis, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
@@ -38,7 +37,7 @@
     with pm.Model() as m:
         logsigma = pm.Uniform("logsigma", 1, 100)
         mu = pm.Uniform("mu", -10000, 10000)
-        yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+        pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
         vars = [mu, logsigma]
 
     idata = pmx.fit(
@@ -67,7 +66,6 @@
     + "To suppress this warning set `negate_output=False`:FutureWarning",
 )
 def test_laplace_only_fit():
-
     # Example originates from Bayesian Data Analysis, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
     # See section 4.1
@@ -79,7 +77,7 @@
     with pm.Model() as m:
         logsigma = pm.Uniform("logsigma", 1, 100)
         mu = pm.Uniform("mu", -10000, 10000)
-        yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+        pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
         vars = [mu, logsigma]
 
     idata = pmx.fit(
@@ -105,22 +103,20 @@
     + "To suppress this warning set `negate_output=False`:FutureWarning",
 )
 def test_laplace_subset_of_rv(recwarn):
-
     # Example originates from Bayesian Data Analysis, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
     # See section 4.1
 
     y = np.array([2642, 3503, 4358], dtype=np.float64)
-    n = y.size
 
     with pm.Model() as m:
         logsigma = pm.Uniform("logsigma", 1, 100)
         mu = pm.Uniform("mu", -10000, 10000)
-        yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+        pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
         vars = [mu]
 
-    idata = pmx.fit(
+    pmx.fit(
         method="laplace",
         vars=vars,
         draws=None,
diff --git a/tests/test_linearmodel.py b/tests/test_linearmodel.py
index d969dbef2..6938a67a1 100644
--- a/tests/test_linearmodel.py
+++ b/tests/test_linearmodel.py
@@ -75,7 +75,8 @@ def fitted_linear_model_instance(toy_X, toy_y):
 
 @pytest.mark.skipif(
-    sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
+ sys.platform == "win32", + reason="Permissions for temp files not granted on windows CI.", ) def test_save_load(fitted_linear_model_instance): model = fitted_linear_model_instance @@ -123,9 +124,15 @@ def test_parameter_fit(toy_X, toy_y, toy_actual_params): model = LinearModel(sampler_config=sampler_config) model.fit(toy_X, toy_y, random_seed=312) fit_params = model.idata.posterior.mean() - np.testing.assert_allclose(fit_params["intercept"], toy_actual_params["intercept"], rtol=0.1) - np.testing.assert_allclose(fit_params["slope"], toy_actual_params["slope"], rtol=0.1) - np.testing.assert_allclose(fit_params["σ_model_fmc"], toy_actual_params["obs_error"], rtol=0.1) + np.testing.assert_allclose( + fit_params["intercept"], toy_actual_params["intercept"], rtol=0.1 + ) + np.testing.assert_allclose( + fit_params["slope"], toy_actual_params["slope"], rtol=0.1 + ) + np.testing.assert_allclose( + fit_params["σ_model_fmc"], toy_actual_params["obs_error"], rtol=0.1 + ) def test_predict(fitted_linear_model_instance): @@ -154,12 +161,14 @@ def test_predict_posterior(fitted_linear_model_instance, combined): @pytest.mark.parametrize("combined", [True, False]) def test_sample_prior_predictive(samples, combined, toy_X, toy_y): model = LinearModel() - prior_pred = model.sample_prior_predictive(toy_X, toy_y, samples, combined=combined)[ - model.output_var - ] + prior_pred = model.sample_prior_predictive( + toy_X, toy_y, samples, combined=combined + )[model.output_var] draws = model.sampler_config["draws"] if samples is None else samples chains = 1 - expected_shape = (len(toy_X), chains * draws) if combined else (chains, draws, len(toy_X)) + expected_shape = ( + (len(toy_X), chains * draws) if combined else (chains, draws, len(toy_X)) + ) assert prior_pred.shape == expected_shape # TODO: check that extend_idata has the expected effect @@ -179,13 +188,17 @@ def test_id(): model = LinearModel(model_config=model_config, sampler_config=sampler_config) expected_id = hashlib.sha256( - str(model_config.values()).encode() + model.version.encode() + model._model_type.encode() + str(model_config.values()).encode() + + model.version.encode() + + model._model_type.encode() ).hexdigest()[:16] assert model.id == expected_id -@pytest.mark.skipif(not sklearn_available, reason="scikit-learn package is not available.") +@pytest.mark.skipif( + not sklearn_available, reason="scikit-learn package is not available." +) def test_pipeline_integration(toy_X, toy_y): model_config = { "intercept": {"loc": 0, "scale": 2}, @@ -198,7 +211,9 @@ def test_pipeline_integration(toy_X, toy_y): ("input_scaling", StandardScaler()), ( "linear_model", - TransformedTargetRegressor(LinearModel(model_config), transformer=StandardScaler()), + TransformedTargetRegressor( + LinearModel(model_config), transformer=StandardScaler() + ), ), ] ) diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 775f27302..8744028be 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -57,7 +57,9 @@ def get_unfitted_model_instance(X, y): "obs_error": 2, } model = test_ModelBuilder( - model_config=model_config, sampler_config=sampler_config, test_parameter="test_paramter" + model_config=model_config, + sampler_config=sampler_config, + test_parameter="test_paramter", ) # Do the things that `model.fit` does except sample to create idata. 
model._generate_and_preprocess_model_data(X, y.values.flatten()) @@ -116,7 +118,7 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None): obs_error = pm.HalfNormal("σ_model_fmc", obs_error) # observed data - output = pm.Normal("output", a + b * x, obs_error, shape=x.shape, observed=y_data) + pm.Normal("output", a + b * x, obs_error, shape=x.shape, observed=y_data) def _save_input_params(self, idata): idata.attrs["test_paramter"] = json.dumps(self.test_parameter) @@ -168,7 +170,8 @@ def test_save_input_params(fitted_model_instance): @pytest.mark.skipif( - sys.platform == "win32", reason="Permissions for temp files not granted on windows CI." + sys.platform == "win32", + reason="Permissions for temp files not granted on windows CI.", ) def test_save_load(fitted_model_instance): temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) @@ -205,8 +208,10 @@ def test_empty_sampler_config_fit(toy_X, toy_y): def test_fit(fitted_model_instance): - prediction_data = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)}) - pred = fitted_model_instance.predict(prediction_data["input"]) + prediction_data = pd.DataFrame( + {"input": np.random.uniform(low=0, high=1, size=100)} + ) + fitted_model_instance.predict(prediction_data["input"]) post_pred = fitted_model_instance.sample_posterior_predictive( prediction_data["input"], extend_idata=True, combined=True ) @@ -226,7 +231,7 @@ def test_predict(fitted_model_instance): prediction_data = pd.DataFrame({"input": x_pred}) pred = fitted_model_instance.predict(prediction_data["input"]) # Perform elementwise comparison using numpy - assert type(pred) == np.ndarray + assert type(pred) is np.ndarray assert len(pred) > 0 @@ -261,7 +266,9 @@ def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idat else: # group == "posterior_predictive": prediction_method = fitted_model_instance.sample_posterior_predictive - pred = prediction_method(prediction_data["input"], combined=False, extend_idata=extend_idata) + pred = prediction_method( + prediction_data["input"], combined=False, extend_idata=extend_idata + ) pred_unstacked = pred[output_var].values idata_now = fitted_model_instance.idata[group][output_var].values diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py index 3ddd4a4fb..f82c965c3 100644 --- a/tests/test_pathfinder.py +++ b/tests/test_pathfinder.py @@ -28,12 +28,12 @@ def test_pathfinder(): y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) - with pm.Model() as model: + with pm.Model(): mu = pm.Normal("mu", mu=0.0, sigma=10.0) tau = pm.HalfCauchy("tau", 5.0) theta = pm.Normal("theta", mu=0, sigma=1, shape=J) - obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y) + pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y) idata = pmx.fit(method="pathfinder", random_seed=41) diff --git a/tests/test_prior_from_trace.py b/tests/test_prior_from_trace.py index f6bcd3663..7a6b05203 100644 --- a/tests/test_prior_from_trace.py +++ b/tests/test_prior_from_trace.py @@ -33,7 +33,10 @@ dict(name="a", transform=transforms.log, dims=None), ), (("a", dict(name="b")), dict(name="b", transform=None, dims=None)), - (("a", dict(name="b", dims="test")), dict(name="b", transform=None, dims="test")), + ( + ("a", dict(name="b", dims="test")), + dict(name="b", transform=None, dims="test"), + ), (("a", ("test",)), dict(name="a", transform=None, dims=("test",))), ], ) @@ -149,18 +152,18 @@ 
def test_mean_chol(flat_info): def test_mvn_prior_from_flat_info(flat_info, coords, param_cfg): with pm.Model(coords=coords) as model: - priors = pmx.utils.prior._mvn_prior_from_flat_info("trace_prior_", flat_info) - test_prior = pm.sample_prior_predictive(1) + pmx.utils.prior._mvn_prior_from_flat_info("trace_prior_", flat_info) + pm.sample_prior_predictive(1) names = [p["name"] for p in param_cfg.values()] assert set(model.named_vars) == {"trace_prior_", *names} def test_prior_from_idata(idata, user_param_cfg, coords, param_cfg): with pm.Model(coords=coords) as model: - priors = pmx.utils.prior.prior_from_idata( + pmx.utils.prior.prior_from_idata( idata, var_names=user_param_cfg[0], **user_param_cfg[1] ) - test_prior = pm.sample_prior_predictive(1) + pm.sample_prior_predictive(1) names = [p["name"] for p in param_cfg.values()] assert set(model.named_vars) == {"trace_prior_", *names} diff --git a/tests/test_splines.py b/tests/test_splines.py index d5eab9b50..496eccee5 100644 --- a/tests/test_splines.py +++ b/tests/test_splines.py @@ -44,7 +44,9 @@ def test_spline_construction(dtype, sparse): @pytest.mark.parametrize("shape", [(100,), (100, 5)]) @pytest.mark.parametrize("sparse", [True, False]) -@pytest.mark.parametrize("points", [dict(n=1001), dict(eval_points=np.linspace(0, 1, 1001))]) +@pytest.mark.parametrize( + "points", [dict(n=1001), dict(eval_points=np.linspace(0, 1, 1001))] +) def test_interpolation_api(shape, sparse, points): x = np.random.randn(*shape) yt = pmx.utils.spline.bspline_interpolation(x, **points, sparse=sparse) @@ -55,7 +57,11 @@ def test_interpolation_api(shape, sparse, points): @pytest.mark.parametrize( "params", [ - (dict(sparse="foo", n=100, degree=1), TypeError, "sparse should be True or False"), + ( + dict(sparse="foo", n=100, degree=1), + TypeError, + "sparse should be True or False", + ), (dict(n=100, degree=0.5), TypeError, "degree should be integer"), ( dict(n=100, eval_points=np.linspace(0, 1), degree=1),