diff --git a/notebooks/Exponential Trend Smoothing.ipynb b/notebooks/Exponential Trend Smoothing.ipynb index dece07ea6..42afbbf81 100644 --- a/notebooks/Exponential Trend Smoothing.ipynb +++ b/notebooks/Exponential Trend Smoothing.ipynb @@ -15,18 +15,19 @@ "\n", "warnings.filterwarnings(action=\"ignore\", message=\"The RandomType SharedVariables\")\n", "\n", - "import matplotlib.pyplot as plt\n", - "from matplotlib.dates import DayLocator, MonthLocator, YearLocator, DateFormatter\n", - "\n", - "import pymc_extras.statespace as pmss\n", - "import pymc as pm\n", + "import os\n", "\n", - "import preliz as pz\n", - "import pandas as pd\n", "import arviz as az\n", - "import yfinance as yf\n", + "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "import os\n", + "import pandas as pd\n", + "import preliz as pz\n", + "import pymc as pm\n", + "import yfinance as yf\n", + "\n", + "from matplotlib.dates import DateFormatter, DayLocator, MonthLocator, YearLocator\n", + "\n", + "import pymc_extras.statespace as pmss\n", "\n", "plt.rcParams.update(\n", " {\n", @@ -61,11 +62,14 @@ "metadata": {}, "outputs": [], "source": [ - "import ipywidgets as widgets\n", - "from IPython.display import display\n", "from functools import partial\n", + "\n", + "import ipywidgets as widgets\n", "import pytensor\n", "import pytensor.tensor as pt\n", + "\n", + "from IPython.display import display\n", + "\n", "from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace\n", "\n", "\n", @@ -344,8 +348,6 @@ "metadata": {}, "outputs": [], "source": [ - "import datetime\n", - "\n", "FORCE_UPDATE = False\n", "\n", "if not os.path.isdir(\"data\"):\n", diff --git a/notebooks/Making a Custom Statespace Model.ipynb b/notebooks/Making a Custom Statespace Model.ipynb index f41818158..c95b188a7 100644 --- a/notebooks/Making a Custom Statespace Model.ipynb +++ b/notebooks/Making a Custom Statespace Model.ipynb @@ -15,13 +15,13 @@ } ], "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", "import arviz as az\n", - "\n", - "from pymc_extras.statespace.core.statespace import PyMCStateSpace\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pymc as pm\n", "import pytensor.tensor as pt\n", - "import pymc as pm" + "\n", + "from pymc_extras.statespace.core.statespace import PyMCStateSpace" ] }, { @@ -45,7 +45,7 @@ "\n", "\n", "def print_model_ssm(mod, how=\"eval\"):\n", - " nice_heading = f'{\"name\":<20}{\"__repr__\":<50}{\"shape\":<10}{\"value\":<20}'\n", + " nice_heading = f\"{'name':<20}{'__repr__':<50}{'shape':<10}{'value':<20}\"\n", " print(nice_heading)\n", " print(\"=\" * len(nice_heading))\n", " if how == \"eval\":\n", @@ -1270,7 +1270,7 @@ ], "source": [ "az.plot_posterior(\n", - " idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=true_ar.tolist() + [true_sigma_x]\n", + " idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=[*true_ar.tolist(), true_sigma_x]\n", ");" ] }, @@ -1333,13 +1333,12 @@ "metadata": {}, "outputs": [], "source": [ + "from pymc_extras.statespace.models.utilities import make_default_coords\n", "from pymc_extras.statespace.utils.constants import (\n", - " ALL_STATE_DIM,\n", " ALL_STATE_AUX_DIM,\n", - " OBS_STATE_DIM,\n", + " ALL_STATE_DIM,\n", " SHOCK_DIM,\n", ")\n", - "from pymc_extras.statespace.models.utilities import make_default_coords\n", "\n", "\n", "class AutoRegressiveThree(PyMCStateSpace):\n", diff --git a/notebooks/SARMA Example.ipynb b/notebooks/SARMA Example.ipynb index 2d9067fef..3cb2078ae 100644 --- a/notebooks/SARMA Example.ipynb +++ b/notebooks/SARMA Example.ipynb @@ -31,20 +31,19 @@ "\n", "jax.config.update(\"jax_platform_name\", \"cpu\")\n", "\n", - "import pymc as pm\n", - "from pytensor import tensor as pt\n", + "import warnings\n", "\n", "import arviz as az\n", - "import statsmodels.api as sm\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", - "from scipy import stats\n", + "import pymc as pm\n", + "import statsmodels.api as sm\n", "\n", - "import pymc_extras.statespace as pmss\n", "from pymc.model.transform.optimization import freeze_dims_and_data\n", + "from pytensor import tensor as pt\n", "\n", - "import warnings\n", + "import pymc_extras.statespace as pmss\n", "\n", "warnings.filterwarnings(action=\"ignore\", message=\"The RandomType SharedVariables\")\n", "\n", @@ -2697,8 +2696,8 @@ "source": [ "fig, ax = plt.subplots()\n", "post = az.extract(post_pred).map(np.exp)\n", - "hdi = az.hdi(post_pred.map(np.exp))[f\"predicted_posterior_observed\"]\n", - "post[f\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n", + "hdi = az.hdi(post_pred.map(np.exp))[\"predicted_posterior_observed\"]\n", + "post[\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n", " x=\"time\", ax=ax, add_legend=False, label=\"Posterior Mean, Predicted\"\n", ")\n", "ax.fill_between(\n", diff --git a/notebooks/Structural Timeseries Modeling.ipynb b/notebooks/Structural Timeseries Modeling.ipynb index 941b32afd..0be6b6952 100644 --- a/notebooks/Structural Timeseries Modeling.ipynb +++ b/notebooks/Structural Timeseries Modeling.ipynb @@ -20,18 +20,20 @@ } ], "source": [ - "from pymc_extras.statespace import structural as st\n", - "from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG\n", + "import warnings\n", + "\n", + "import arviz as az\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import preliz as pz\n", "import pymc as pm\n", - "import arviz as az\n", "import pytensor.tensor as pt\n", - "import preliz as pz\n", "\n", - "import numpy as np\n", - "import pandas as pd\n", "from patsy import dmatrix\n", - "import warnings\n", + "\n", + "from pymc_extras.statespace import structural as st\n", + "from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG\n", "\n", "warnings.filterwarnings(\"ignore\", category=UserWarning, message=\"The RandomType SharedVariables\")\n", "\n", @@ -76,13 +78,14 @@ }, "outputs": [], "source": [ - "from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace\n", "from pymc.pytensorf import inputvars\n", "\n", + "from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace\n", + "\n", "\n", "def make_numpy_function(mod):\n", " mod = mod.build(verbose=False)\n", - " data = pt.matrix(\"data\", shape=(None, 1))\n", + " pt.matrix(\"data\", shape=(None, 1))\n", " steps = pt.iscalar(\"steps\")\n", " x0, _, c, d, T, Z, R, H, Q = mod._unpack_statespace_with_placeholders()\n", " sequence_names = [x.name for x in [c, d] if x.ndim == 2]\n", @@ -3113,7 +3116,6 @@ }, "outputs": [], "source": [ - "from nutpie import transform_adapter\n", "import nutpie as ntp\n", "\n", "# Uncomment to see arguments for with_transform_adapt\n", diff --git a/notebooks/VARMAX Example.ipynb b/notebooks/VARMAX Example.ipynb index 7bcadbd13..0dc19386e 100644 --- a/notebooks/VARMAX Example.ipynb +++ b/notebooks/VARMAX Example.ipynb @@ -11,21 +11,21 @@ "\n", "jax.config.update(\"jax_platform_name\", \"cpu\")\n", "\n", + "import sys\n", + "\n", + "import arviz as az\n", + "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "import statsmodels.api as sm\n", "import pandas as pd\n", - "\n", "import pymc as pm\n", "import pytensor.tensor as pt\n", - "import arviz as az\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import sys\n", + "import statsmodels.api as sm\n", "\n", "sys.path.append(\"..\")\n", - "import pymc_extras.statespace as pmss\n", "import re\n", "\n", + "import pymc_extras.statespace as pmss\n", + "\n", "config = {\n", " \"figure.figsize\": [12.0, 4.0],\n", " \"figure.dpi\": 72.0 * 2,\n", @@ -841,7 +841,7 @@ " new_labels = []\n", " for label in axis.yaxis.get_majorticklabels():\n", " old_text = \"[\" + label.get_text().split(\"[\")[-1]\n", - " labels = eval(re.sub(\"([\\d\\w]+)\", '\"\\g<1>\"', old_text))\n", + " labels = eval(re.sub(r\"([\\d\\w]+)\", r'\"\\g<1>\"', old_text))\n", " lag, other_var = labels\n", " new_text = f\"L{lag}.{other_var}\"\n", " new_labels.append(new_text)\n", diff --git a/notebooks/discrete_markov_chain.ipynb b/notebooks/discrete_markov_chain.ipynb index c267d42fe..a4bb48a42 100644 --- a/notebooks/discrete_markov_chain.ipynb +++ b/notebooks/discrete_markov_chain.ipynb @@ -27,13 +27,13 @@ "outputs": [], "source": [ "import arviz as az\n", + "import matplotlib.pyplot as plt\n", "import numpy as np\n", + "import pandas as pd\n", "import pymc as pm\n", "import pytensor\n", "import pytensor.tensor as pt\n", - "import pandas as pd\n", "import statsmodels.api as sm\n", - "import matplotlib.pyplot as plt\n", "\n", "from matplotlib import ticker as mtick\n", "\n", @@ -525,7 +525,7 @@ " \"dates\": dta_hamilton.index,\n", " \"obs_dates\": dta_hamilton.index[order:],\n", " \"states\": [\"State_1\", \"State_2\"],\n", - " \"ar_params\": [f\"L{i+1}.phi\" for i in range(order)],\n", + " \"ar_params\": [f\"L{i + 1}.phi\" for i in range(order)],\n", "}\n", "\n", "with pm.Model(coords=coords) as hmm:\n", diff --git a/notebooks/latent_approx.ipynb b/notebooks/latent_approx.ipynb index ed1202c0a..242422ff2 100644 --- a/notebooks/latent_approx.ipynb +++ b/notebooks/latent_approx.ipynb @@ -25,13 +25,9 @@ } ], "source": [ - "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "import pymc as pm\n", - "from aesara import tensor as at\n", - "\n", - "import pymc_extras as pmx" + "import pymc as pm" ] }, { @@ -42,7 +38,7 @@ "outputs": [], "source": [ "from pymc_extras.gp import HSGP, KarhunenLoeveExpansion, ProjectedProcess\n", - "from pymc_extras.gp.latent_approx import ExpQuad, Matern12, Matern32, Matern52" + "from pymc_extras.gp.latent_approx import ExpQuad" ] }, { diff --git a/notebooks/marginalized_changepoint_model.ipynb b/notebooks/marginalized_changepoint_model.ipynb index ac9f34814..c298ba2ce 100644 --- a/notebooks/marginalized_changepoint_model.ipynb +++ b/notebooks/marginalized_changepoint_model.ipynb @@ -6,11 +6,12 @@ "metadata": {}, "outputs": [], "source": [ - "import pymc as pm\n", - "from pymc_extras.model.marginal.marginal_model import MarginalModel\n", - "import pandas as pd\n", + "import arviz as az\n", "import numpy as np\n", - "import arviz as az" + "import pandas as pd\n", + "import pymc as pm\n", + "\n", + "from pymc_extras.model.marginal.marginal_model import MarginalModel" ] }, { diff --git a/pymc_extras/distributions/__init__.py b/pymc_extras/distributions/__init__.py index 783154ba3..2afabad2b 100644 --- a/pymc_extras/distributions/__init__.py +++ b/pymc_extras/distributions/__init__.py @@ -29,14 +29,14 @@ from pymc_extras.distributions.transforms import PartialOrder __all__ = [ + "R2D2M2CP", + "BetaNegativeBinomial", "Chi", - "Maxwell", "DiscreteMarkovChain", - "GeneralizedPoisson", - "BetaNegativeBinomial", "GenExtreme", - "R2D2M2CP", + "GeneralizedPoisson", + "Maxwell", + "PartialOrder", "Skellam", "histogram_approximation", - "PartialOrder", ] diff --git a/pymc_extras/distributions/histogram_utils.py b/pymc_extras/distributions/histogram_utils.py index 70b5c43d8..5cf899032 100644 --- a/pymc_extras/distributions/histogram_utils.py +++ b/pymc_extras/distributions/histogram_utils.py @@ -18,7 +18,7 @@ from numpy.typing import ArrayLike -__all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"] +__all__ = ["discrete_histogram", "histogram_approximation", "quantile_histogram"] def quantile_histogram( diff --git a/pymc_extras/inference/__init__.py b/pymc_extras/inference/__init__.py index a536f91e6..c213f4208 100644 --- a/pymc_extras/inference/__init__.py +++ b/pymc_extras/inference/__init__.py @@ -17,4 +17,4 @@ from pymc_extras.inference.laplace_approx.laplace import fit_laplace from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder -__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"] +__all__ = ["find_MAP", "fit", "fit_laplace", "fit_pathfinder"] diff --git a/pymc_extras/printing.py b/pymc_extras/printing.py index 8409e3426..9f82238ae 100644 --- a/pymc_extras/printing.py +++ b/pymc_extras/printing.py @@ -166,7 +166,7 @@ def model_table( for var in group: var_name = var.name - sep = f'[b]{" ~" if (var in model.basic_RVs) else " ="}[/b]' + sep = f"[b]{' ~' if (var in model.basic_RVs) else ' ='}[/b]" var_expr = variable_expression(model, var, truncate_deterministic) dims_expr = dims_expression(model, var) if dims_expr == "[]": diff --git a/pymc_extras/statespace/__init__.py b/pymc_extras/statespace/__init__.py index 314b11717..ecba1be80 100644 --- a/pymc_extras/statespace/__init__.py +++ b/pymc_extras/statespace/__init__.py @@ -5,9 +5,9 @@ from pymc_extras.statespace.models.VARMAX import BayesianVARMAX __all__ = [ - "compile_statespace", - "structural", "BayesianETS", "BayesianSARIMA", "BayesianVARMAX", + "compile_statespace", + "structural", ] diff --git a/pymc_extras/statespace/core/__init__.py b/pymc_extras/statespace/core/__init__.py index 6e7c67d64..723337a1e 100644 --- a/pymc_extras/statespace/core/__init__.py +++ b/pymc_extras/statespace/core/__init__.py @@ -4,4 +4,4 @@ from pymc_extras.statespace.core.statespace import PyMCStateSpace from pymc_extras.statespace.core.compile import compile_statespace -__all__ = ["PytensorRepresentation", "PyMCStateSpace", "compile_statespace"] +__all__ = ["PyMCStateSpace", "PytensorRepresentation", "compile_statespace"] diff --git a/pymc_extras/statespace/core/representation.py b/pymc_extras/statespace/core/representation.py index 450108d38..f583f97d7 100644 --- a/pymc_extras/statespace/core/representation.py +++ b/pymc_extras/statespace/core/representation.py @@ -153,19 +153,19 @@ class PytensorRepresentation: """ __slots__ = ( + "design", + "initial_state", + "initial_state_cov", "k_endog", - "k_states", "k_posdef", - "shapes", - "design", - "obs_intercept", + "k_states", "obs_cov", - "transition", - "state_intercept", + "obs_intercept", "selection", + "shapes", "state_cov", - "initial_state", - "initial_state_cov", + "state_intercept", + "transition", ) def __init__( diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 0b0f3cd43..ed473ca63 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -60,7 +60,7 @@ def _validate_filter_arg(filter_arg): if filter_arg.lower() not in FILTER_OUTPUT_TYPES: raise ValueError( - f'filter_output should be one of {", ".join(FILTER_OUTPUT_TYPES)}, received {filter_arg}' + f"filter_output should be one of {', '.join(FILTER_OUTPUT_TYPES)}, received {filter_arg}" ) diff --git a/pymc_extras/statespace/filters/__init__.py b/pymc_extras/statespace/filters/__init__.py index f76dea8d4..355926990 100644 --- a/pymc_extras/statespace/filters/__init__.py +++ b/pymc_extras/statespace/filters/__init__.py @@ -7,9 +7,9 @@ from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother __all__ = [ - "StandardFilter", - "UnivariateFilter", "KalmanSmoother", - "SquareRootFilter", "LinearGaussianStateSpace", + "SquareRootFilter", + "StandardFilter", + "UnivariateFilter", ] diff --git a/pymc_extras/statespace/models/SARIMAX.py b/pymc_extras/statespace/models/SARIMAX.py index c35b977e5..94b3e3763 100644 --- a/pymc_extras/statespace/models/SARIMAX.py +++ b/pymc_extras/statespace/models/SARIMAX.py @@ -224,7 +224,7 @@ def __init__( if state_structure not in SARIMAX_STATE_STRUCTURES: raise ValueError( f"Got invalid argument {state_structure} for state structure, expected one of " - f'{", ".join(SARIMAX_STATE_STRUCTURES)}' + f"{', '.join(SARIMAX_STATE_STRUCTURES)}" ) if state_structure == "interpretable" and (self.d + self.D) > 0: diff --git a/pymc_extras/statespace/models/__init__.py b/pymc_extras/statespace/models/__init__.py index 6a94cd7e2..4e47c978e 100644 --- a/pymc_extras/statespace/models/__init__.py +++ b/pymc_extras/statespace/models/__init__.py @@ -3,4 +3,4 @@ from pymc_extras.statespace.models.SARIMAX import BayesianSARIMA from pymc_extras.statespace.models.VARMAX import BayesianVARMAX -__all__ = ["structural", "BayesianSARIMA", "BayesianVARMAX", "BayesianETS"] +__all__ = ["BayesianETS", "BayesianSARIMA", "BayesianVARMAX", "structural"] diff --git a/pymc_extras/statespace/models/structural/__init__.py b/pymc_extras/statespace/models/structural/__init__.py index 57cb6d7ac..f0bfb2f0a 100644 --- a/pymc_extras/statespace/models/structural/__init__.py +++ b/pymc_extras/statespace/models/structural/__init__.py @@ -11,11 +11,11 @@ ) __all__ = [ - "LevelTrendComponent", - "MeasurementError", "AutoregressiveComponent", - "TimeSeasonality", + "CycleComponent", "FrequencySeasonality", + "LevelTrendComponent", + "MeasurementError", "RegressionComponent", - "CycleComponent", + "TimeSeasonality", ] diff --git a/pymc_extras/statespace/utils/data_tools.py b/pymc_extras/statespace/utils/data_tools.py index d9f2d6842..cbc5d517c 100644 --- a/pymc_extras/statespace/utils/data_tools.py +++ b/pymc_extras/statespace/utils/data_tools.py @@ -53,7 +53,7 @@ def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=Fal if len(missing_cols) > 0: raise ValueError( "Columns of DataFrame provided as data do not match state names. The following states were" - f'not found: {", ".join(missing_cols)}. This may result in unexpected results in complex' + f"not found: {', '.join(missing_cols)}. This may result in unexpected results in complex" f"statespace models" ) diff --git a/tests/distributions/__init__.py b/tests/distributions/__init__.py index 2b7d71582..e4171b37c 100644 --- a/tests/distributions/__init__.py +++ b/tests/distributions/__init__.py @@ -16,4 +16,4 @@ from pymc_extras.distributions import histogram_utils from pymc_extras.distributions.histogram_utils import histogram_approximation -__all__ = ["histogram_utils", "histogram_approximation"] +__all__ = ["histogram_approximation", "histogram_utils"] diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index 06aec484e..507cfa996 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -884,8 +884,7 @@ def test_invalid_scenarios(): # Giving a list, tuple, or Series when a matrix of data is expected should always raise with pytest.raises( ValueError, - match="Scenario data for variable 'a' has the wrong number of columns. " - "Expected 2, got 1", + match="Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1", ): for data_type in [list, tuple, pd.Series]: ss_mod._validate_scenario_data(data_type(np.zeros(10))) @@ -894,15 +893,14 @@ def test_invalid_scenarios(): # Providing irrevelant data raises with pytest.raises( ValueError, - match="Scenario data provided for variable 'jk lol', which is not an exogenous " "variable", + match="Scenario data provided for variable 'jk lol', which is not an exogenous variable", ): ss_mod._validate_scenario_data({"jk lol": np.zeros(10)}) # Incorrect 2nd dimension of a non-dataframe with pytest.raises( ValueError, - match="Scenario data for variable 'a' has the wrong number of columns. Expected " - "2, got 1", + match="Scenario data for variable 'a' has the wrong number of columns. Expected 2, got 1", ): scenario = np.zeros(10).tolist() ss_mod._validate_scenario_data(scenario) diff --git a/tests/statespace/models/test_SARIMAX.py b/tests/statespace/models/test_SARIMAX.py index d04303b2f..8e6a543b3 100644 --- a/tests/statespace/models/test_SARIMAX.py +++ b/tests/statespace/models/test_SARIMAX.py @@ -137,7 +137,7 @@ "state_star_3", "state_star_4", ], - ["data", "data_star"] + [f"state_star_{i+1}" for i in range(26)], + ["data", "data_star"] + [f"state_star_{i + 1}" for i in range(26)], ] test_orders = [ diff --git a/tests/test_histogram_approximation.py b/tests/test_histogram_approximation.py index 968b4571b..213096509 100644 --- a/tests/test_histogram_approximation.py +++ b/tests/test_histogram_approximation.py @@ -44,7 +44,7 @@ def test_histogram_init_cont(use_dask, zero_inflation, ndims): assert histogram["mid"].shape == (size,) + (1,) * len(data.shape[1:]) assert histogram["lower"].shape == (size,) + (1,) * len(data.shape[1:]) assert histogram["upper"].shape == (size,) + (1,) * len(data.shape[1:]) - assert histogram["count"].shape == (size,) + data.shape[1:] + assert histogram["count"].shape == (size, *data.shape[1:]) assert (histogram["count"].sum(0) == 10000).all() if zero_inflation: (histogram["count"][0] == 100).all() @@ -71,7 +71,7 @@ def test_histogram_init_discrete(use_dask, min_count, ndims): else: size = len(u) assert histogram["mid"].shape == (size,) + (1,) * len(data.shape[1:]) - assert histogram["count"].shape == (size,) + data.shape[1:] + assert histogram["count"].shape == (size, *data.shape[1:]) if not min_count: assert (histogram["count"].sum(0) == 10000).all()