Skip to content

pre-commit #563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions notebooks/Exponential Trend Smoothing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
"\n",
"warnings.filterwarnings(action=\"ignore\", message=\"The RandomType SharedVariables\")\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.dates import DayLocator, MonthLocator, YearLocator, DateFormatter\n",
"\n",
"import pymc_extras.statespace as pmss\n",
"import pymc as pm\n",
"import os\n",
"\n",
"import preliz as pz\n",
"import pandas as pd\n",
"import arviz as az\n",
"import yfinance as yf\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"import pandas as pd\n",
"import preliz as pz\n",
"import pymc as pm\n",
"import yfinance as yf\n",
"\n",
"from matplotlib.dates import DateFormatter, DayLocator, MonthLocator, YearLocator\n",
"\n",
"import pymc_extras.statespace as pmss\n",
"\n",
"plt.rcParams.update(\n",
" {\n",
Expand Down Expand Up @@ -61,11 +62,14 @@
"metadata": {},
"outputs": [],
"source": [
"import ipywidgets as widgets\n",
"from IPython.display import display\n",
"from functools import partial\n",
"\n",
"import ipywidgets as widgets\n",
"import pytensor\n",
"import pytensor.tensor as pt\n",
"\n",
"from IPython.display import display\n",
"\n",
"from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace\n",
"\n",
"\n",
Expand Down Expand Up @@ -344,8 +348,6 @@
"metadata": {},
"outputs": [],
"source": [
"import datetime\n",
"\n",
"FORCE_UPDATE = False\n",
"\n",
"if not os.path.isdir(\"data\"):\n",
Expand Down
19 changes: 9 additions & 10 deletions notebooks/Making a Custom Statespace Model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import arviz as az\n",
"\n",
"from pymc_extras.statespace.core.statespace import PyMCStateSpace\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"import pymc as pm"
"\n",
"from pymc_extras.statespace.core.statespace import PyMCStateSpace"
]
},
{
Expand All @@ -45,7 +45,7 @@
"\n",
"\n",
"def print_model_ssm(mod, how=\"eval\"):\n",
" nice_heading = f'{\"name\":<20}{\"__repr__\":<50}{\"shape\":<10}{\"value\":<20}'\n",
" nice_heading = f\"{'name':<20}{'__repr__':<50}{'shape':<10}{'value':<20}\"\n",
" print(nice_heading)\n",
" print(\"=\" * len(nice_heading))\n",
" if how == \"eval\":\n",
Expand Down Expand Up @@ -1270,7 +1270,7 @@
],
"source": [
"az.plot_posterior(\n",
" idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=true_ar.tolist() + [true_sigma_x]\n",
" idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=[*true_ar.tolist(), true_sigma_x]\n",
");"
]
},
Expand Down Expand Up @@ -1333,13 +1333,12 @@
"metadata": {},
"outputs": [],
"source": [
"from pymc_extras.statespace.models.utilities import make_default_coords\n",
"from pymc_extras.statespace.utils.constants import (\n",
" ALL_STATE_DIM,\n",
" ALL_STATE_AUX_DIM,\n",
" OBS_STATE_DIM,\n",
" ALL_STATE_DIM,\n",
" SHOCK_DIM,\n",
")\n",
"from pymc_extras.statespace.models.utilities import make_default_coords\n",
"\n",
"\n",
"class AutoRegressiveThree(PyMCStateSpace):\n",
Expand Down
15 changes: 7 additions & 8 deletions notebooks/SARMA Example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,19 @@
"\n",
"jax.config.update(\"jax_platform_name\", \"cpu\")\n",
"\n",
"import pymc as pm\n",
"from pytensor import tensor as pt\n",
"import warnings\n",
"\n",
"import arviz as az\n",
"import statsmodels.api as sm\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"from scipy import stats\n",
"import pymc as pm\n",
"import statsmodels.api as sm\n",
"\n",
"import pymc_extras.statespace as pmss\n",
"from pymc.model.transform.optimization import freeze_dims_and_data\n",
"from pytensor import tensor as pt\n",
"\n",
"import warnings\n",
"import pymc_extras.statespace as pmss\n",
"\n",
"warnings.filterwarnings(action=\"ignore\", message=\"The RandomType SharedVariables\")\n",
"\n",
Expand Down Expand Up @@ -2697,8 +2696,8 @@
"source": [
"fig, ax = plt.subplots()\n",
"post = az.extract(post_pred).map(np.exp)\n",
"hdi = az.hdi(post_pred.map(np.exp))[f\"predicted_posterior_observed\"]\n",
"post[f\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n",
"hdi = az.hdi(post_pred.map(np.exp))[\"predicted_posterior_observed\"]\n",
"post[\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n",
" x=\"time\", ax=ax, add_legend=False, label=\"Posterior Mean, Predicted\"\n",
")\n",
"ax.fill_between(\n",
Expand Down
22 changes: 12 additions & 10 deletions notebooks/Structural Timeseries Modeling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,20 @@
}
],
"source": [
"from pymc_extras.statespace import structural as st\n",
"from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG\n",
"import warnings\n",
"\n",
"import arviz as az\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import preliz as pz\n",
"import pymc as pm\n",
"import arviz as az\n",
"import pytensor.tensor as pt\n",
"import preliz as pz\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from patsy import dmatrix\n",
"import warnings\n",
"\n",
"from pymc_extras.statespace import structural as st\n",
"from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG\n",
"\n",
"warnings.filterwarnings(\"ignore\", category=UserWarning, message=\"The RandomType SharedVariables\")\n",
"\n",
Expand Down Expand Up @@ -76,13 +78,14 @@
},
"outputs": [],
"source": [
"from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace\n",
"from pymc.pytensorf import inputvars\n",
"\n",
"from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace\n",
"\n",
"\n",
"def make_numpy_function(mod):\n",
" mod = mod.build(verbose=False)\n",
" data = pt.matrix(\"data\", shape=(None, 1))\n",
" pt.matrix(\"data\", shape=(None, 1))\n",
" steps = pt.iscalar(\"steps\")\n",
" x0, _, c, d, T, Z, R, H, Q = mod._unpack_statespace_with_placeholders()\n",
" sequence_names = [x.name for x in [c, d] if x.ndim == 2]\n",
Expand Down Expand Up @@ -3113,7 +3116,6 @@
},
"outputs": [],
"source": [
"from nutpie import transform_adapter\n",
"import nutpie as ntp\n",
"\n",
"# Uncomment to see arguments for with_transform_adapt\n",
Expand Down
16 changes: 8 additions & 8 deletions notebooks/VARMAX Example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@
"\n",
"jax.config.update(\"jax_platform_name\", \"cpu\")\n",
"\n",
"import sys\n",
"\n",
"import arviz as az\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import statsmodels.api as sm\n",
"import pandas as pd\n",
"\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"import arviz as az\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import sys\n",
"import statsmodels.api as sm\n",
"\n",
"sys.path.append(\"..\")\n",
"import pymc_extras.statespace as pmss\n",
"import re\n",
"\n",
"import pymc_extras.statespace as pmss\n",
"\n",
"config = {\n",
" \"figure.figsize\": [12.0, 4.0],\n",
" \"figure.dpi\": 72.0 * 2,\n",
Expand Down Expand Up @@ -841,7 +841,7 @@
" new_labels = []\n",
" for label in axis.yaxis.get_majorticklabels():\n",
" old_text = \"[\" + label.get_text().split(\"[\")[-1]\n",
" labels = eval(re.sub(\"([\\d\\w]+)\", '\"\\g<1>\"', old_text))\n",
" labels = eval(re.sub(r\"([\\d\\w]+)\", r'\"\\g<1>\"', old_text))\n",
" lag, other_var = labels\n",
" new_text = f\"L{lag}.{other_var}\"\n",
" new_labels.append(new_text)\n",
Expand Down
6 changes: 3 additions & 3 deletions notebooks/discrete_markov_chain.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
"outputs": [],
"source": [
"import arviz as az\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pymc as pm\n",
"import pytensor\n",
"import pytensor.tensor as pt\n",
"import pandas as pd\n",
"import statsmodels.api as sm\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from matplotlib import ticker as mtick\n",
"\n",
Expand Down Expand Up @@ -525,7 +525,7 @@
" \"dates\": dta_hamilton.index,\n",
" \"obs_dates\": dta_hamilton.index[order:],\n",
" \"states\": [\"State_1\", \"State_2\"],\n",
" \"ar_params\": [f\"L{i+1}.phi\" for i in range(order)],\n",
" \"ar_params\": [f\"L{i + 1}.phi\" for i in range(order)],\n",
"}\n",
"\n",
"with pm.Model(coords=coords) as hmm:\n",
Expand Down
8 changes: 2 additions & 6 deletions notebooks/latent_approx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,9 @@
}
],
"source": [
"import arviz as az\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pymc as pm\n",
"from aesara import tensor as at\n",
"\n",
"import pymc_extras as pmx"
"import pymc as pm"
]
},
{
Expand All @@ -42,7 +38,7 @@
"outputs": [],
"source": [
"from pymc_extras.gp import HSGP, KarhunenLoeveExpansion, ProjectedProcess\n",
"from pymc_extras.gp.latent_approx import ExpQuad, Matern12, Matern32, Matern52"
"from pymc_extras.gp.latent_approx import ExpQuad"
]
},
{
Expand Down
9 changes: 5 additions & 4 deletions notebooks/marginalized_changepoint_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
"metadata": {},
"outputs": [],
"source": [
"import pymc as pm\n",
"from pymc_extras.model.marginal.marginal_model import MarginalModel\n",
"import pandas as pd\n",
"import arviz as az\n",
"import numpy as np\n",
"import arviz as az"
"import pandas as pd\n",
"import pymc as pm\n",
"\n",
"from pymc_extras.model.marginal.marginal_model import MarginalModel"
]
},
{
Expand Down
10 changes: 5 additions & 5 deletions pymc_extras/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
from pymc_extras.distributions.transforms import PartialOrder

__all__ = [
"R2D2M2CP",
"BetaNegativeBinomial",
"Chi",
"Maxwell",
"DiscreteMarkovChain",
"GeneralizedPoisson",
"BetaNegativeBinomial",
"GenExtreme",
"R2D2M2CP",
"GeneralizedPoisson",
"Maxwell",
"PartialOrder",
"Skellam",
"histogram_approximation",
"PartialOrder",
]
2 changes: 1 addition & 1 deletion pymc_extras/distributions/histogram_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from numpy.typing import ArrayLike

__all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]
__all__ = ["discrete_histogram", "histogram_approximation", "quantile_histogram"]


def quantile_histogram(
Expand Down
2 changes: 1 addition & 1 deletion pymc_extras/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
from pymc_extras.inference.laplace_approx.laplace import fit_laplace
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
__all__ = ["find_MAP", "fit", "fit_laplace", "fit_pathfinder"]
2 changes: 1 addition & 1 deletion pymc_extras/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def model_table(

for var in group:
var_name = var.name
sep = f'[b]{" ~" if (var in model.basic_RVs) else " ="}[/b]'
sep = f"[b]{' ~' if (var in model.basic_RVs) else ' ='}[/b]"
var_expr = variable_expression(model, var, truncate_deterministic)
dims_expr = dims_expression(model, var)
if dims_expr == "[]":
Expand Down
4 changes: 2 additions & 2 deletions pymc_extras/statespace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from pymc_extras.statespace.models.VARMAX import BayesianVARMAX

__all__ = [
"compile_statespace",
"structural",
"BayesianETS",
"BayesianSARIMA",
"BayesianVARMAX",
"compile_statespace",
"structural",
]
2 changes: 1 addition & 1 deletion pymc_extras/statespace/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from pymc_extras.statespace.core.statespace import PyMCStateSpace
from pymc_extras.statespace.core.compile import compile_statespace

__all__ = ["PytensorRepresentation", "PyMCStateSpace", "compile_statespace"]
__all__ = ["PyMCStateSpace", "PytensorRepresentation", "compile_statespace"]
16 changes: 8 additions & 8 deletions pymc_extras/statespace/core/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,19 +153,19 @@ class PytensorRepresentation:
"""

__slots__ = (
"design",
"initial_state",
"initial_state_cov",
"k_endog",
"k_states",
"k_posdef",
"shapes",
"design",
"obs_intercept",
"k_states",
"obs_cov",
"transition",
"state_intercept",
"obs_intercept",
"selection",
"shapes",
"state_cov",
"initial_state",
"initial_state_cov",
"state_intercept",
"transition",
)

def __init__(
Expand Down
Loading
Loading