Skip to content

Commit 32a42b4

Browse files
pre-commit (#563)
1 parent b5aa9bf commit 32a42b4

25 files changed

+98
-101
lines changed

notebooks/Exponential Trend Smoothing.ipynb

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,19 @@
1515
"\n",
1616
"warnings.filterwarnings(action=\"ignore\", message=\"The RandomType SharedVariables\")\n",
1717
"\n",
18-
"import matplotlib.pyplot as plt\n",
19-
"from matplotlib.dates import DayLocator, MonthLocator, YearLocator, DateFormatter\n",
20-
"\n",
21-
"import pymc_extras.statespace as pmss\n",
22-
"import pymc as pm\n",
18+
"import os\n",
2319
"\n",
24-
"import preliz as pz\n",
25-
"import pandas as pd\n",
2620
"import arviz as az\n",
27-
"import yfinance as yf\n",
21+
"import matplotlib.pyplot as plt\n",
2822
"import numpy as np\n",
29-
"import os\n",
23+
"import pandas as pd\n",
24+
"import preliz as pz\n",
25+
"import pymc as pm\n",
26+
"import yfinance as yf\n",
27+
"\n",
28+
"from matplotlib.dates import DateFormatter, DayLocator, MonthLocator, YearLocator\n",
29+
"\n",
30+
"import pymc_extras.statespace as pmss\n",
3031
"\n",
3132
"plt.rcParams.update(\n",
3233
" {\n",
@@ -61,11 +62,14 @@
6162
"metadata": {},
6263
"outputs": [],
6364
"source": [
64-
"import ipywidgets as widgets\n",
65-
"from IPython.display import display\n",
6665
"from functools import partial\n",
66+
"\n",
67+
"import ipywidgets as widgets\n",
6768
"import pytensor\n",
6869
"import pytensor.tensor as pt\n",
70+
"\n",
71+
"from IPython.display import display\n",
72+
"\n",
6973
"from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace\n",
7074
"\n",
7175
"\n",
@@ -344,8 +348,6 @@
344348
"metadata": {},
345349
"outputs": [],
346350
"source": [
347-
"import datetime\n",
348-
"\n",
349351
"FORCE_UPDATE = False\n",
350352
"\n",
351353
"if not os.path.isdir(\"data\"):\n",

notebooks/Making a Custom Statespace Model.ipynb

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
}
1616
],
1717
"source": [
18-
"import numpy as np\n",
19-
"import matplotlib.pyplot as plt\n",
2018
"import arviz as az\n",
21-
"\n",
22-
"from pymc_extras.statespace.core.statespace import PyMCStateSpace\n",
19+
"import matplotlib.pyplot as plt\n",
20+
"import numpy as np\n",
21+
"import pymc as pm\n",
2322
"import pytensor.tensor as pt\n",
24-
"import pymc as pm"
23+
"\n",
24+
"from pymc_extras.statespace.core.statespace import PyMCStateSpace"
2525
]
2626
},
2727
{
@@ -45,7 +45,7 @@
4545
"\n",
4646
"\n",
4747
"def print_model_ssm(mod, how=\"eval\"):\n",
48-
" nice_heading = f'{\"name\":<20}{\"__repr__\":<50}{\"shape\":<10}{\"value\":<20}'\n",
48+
" nice_heading = f\"{'name':<20}{'__repr__':<50}{'shape':<10}{'value':<20}\"\n",
4949
" print(nice_heading)\n",
5050
" print(\"=\" * len(nice_heading))\n",
5151
" if how == \"eval\":\n",
@@ -1270,7 +1270,7 @@
12701270
],
12711271
"source": [
12721272
"az.plot_posterior(\n",
1273-
" idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=true_ar.tolist() + [true_sigma_x]\n",
1273+
" idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=[*true_ar.tolist(), true_sigma_x]\n",
12741274
");"
12751275
]
12761276
},
@@ -1333,13 +1333,12 @@
13331333
"metadata": {},
13341334
"outputs": [],
13351335
"source": [
1336+
"from pymc_extras.statespace.models.utilities import make_default_coords\n",
13361337
"from pymc_extras.statespace.utils.constants import (\n",
1337-
" ALL_STATE_DIM,\n",
13381338
" ALL_STATE_AUX_DIM,\n",
1339-
" OBS_STATE_DIM,\n",
1339+
" ALL_STATE_DIM,\n",
13401340
" SHOCK_DIM,\n",
13411341
")\n",
1342-
"from pymc_extras.statespace.models.utilities import make_default_coords\n",
13431342
"\n",
13441343
"\n",
13451344
"class AutoRegressiveThree(PyMCStateSpace):\n",

notebooks/SARMA Example.ipynb

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,19 @@
3131
"\n",
3232
"jax.config.update(\"jax_platform_name\", \"cpu\")\n",
3333
"\n",
34-
"import pymc as pm\n",
35-
"from pytensor import tensor as pt\n",
34+
"import warnings\n",
3635
"\n",
3736
"import arviz as az\n",
38-
"import statsmodels.api as sm\n",
3937
"import matplotlib.pyplot as plt\n",
4038
"import numpy as np\n",
4139
"import pandas as pd\n",
42-
"from scipy import stats\n",
40+
"import pymc as pm\n",
41+
"import statsmodels.api as sm\n",
4342
"\n",
44-
"import pymc_extras.statespace as pmss\n",
4543
"from pymc.model.transform.optimization import freeze_dims_and_data\n",
44+
"from pytensor import tensor as pt\n",
4645
"\n",
47-
"import warnings\n",
46+
"import pymc_extras.statespace as pmss\n",
4847
"\n",
4948
"warnings.filterwarnings(action=\"ignore\", message=\"The RandomType SharedVariables\")\n",
5049
"\n",
@@ -2697,8 +2696,8 @@
26972696
"source": [
26982697
"fig, ax = plt.subplots()\n",
26992698
"post = az.extract(post_pred).map(np.exp)\n",
2700-
"hdi = az.hdi(post_pred.map(np.exp))[f\"predicted_posterior_observed\"]\n",
2701-
"post[f\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n",
2699+
"hdi = az.hdi(post_pred.map(np.exp))[\"predicted_posterior_observed\"]\n",
2700+
"post[\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n",
27022701
" x=\"time\", ax=ax, add_legend=False, label=\"Posterior Mean, Predicted\"\n",
27032702
")\n",
27042703
"ax.fill_between(\n",

notebooks/Structural Timeseries Modeling.ipynb

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,20 @@
2020
}
2121
],
2222
"source": [
23-
"from pymc_extras.statespace import structural as st\n",
24-
"from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG\n",
23+
"import warnings\n",
24+
"\n",
25+
"import arviz as az\n",
2526
"import matplotlib.pyplot as plt\n",
27+
"import numpy as np\n",
28+
"import pandas as pd\n",
29+
"import preliz as pz\n",
2630
"import pymc as pm\n",
27-
"import arviz as az\n",
2831
"import pytensor.tensor as pt\n",
29-
"import preliz as pz\n",
3032
"\n",
31-
"import numpy as np\n",
32-
"import pandas as pd\n",
3333
"from patsy import dmatrix\n",
34-
"import warnings\n",
34+
"\n",
35+
"from pymc_extras.statespace import structural as st\n",
36+
"from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG\n",
3537
"\n",
3638
"warnings.filterwarnings(\"ignore\", category=UserWarning, message=\"The RandomType SharedVariables\")\n",
3739
"\n",
@@ -76,13 +78,14 @@
7678
},
7779
"outputs": [],
7880
"source": [
79-
"from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace\n",
8081
"from pymc.pytensorf import inputvars\n",
8182
"\n",
83+
"from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace\n",
84+
"\n",
8285
"\n",
8386
"def make_numpy_function(mod):\n",
8487
" mod = mod.build(verbose=False)\n",
85-
" data = pt.matrix(\"data\", shape=(None, 1))\n",
88+
" pt.matrix(\"data\", shape=(None, 1))\n",
8689
" steps = pt.iscalar(\"steps\")\n",
8790
" x0, _, c, d, T, Z, R, H, Q = mod._unpack_statespace_with_placeholders()\n",
8891
" sequence_names = [x.name for x in [c, d] if x.ndim == 2]\n",
@@ -3113,7 +3116,6 @@
31133116
},
31143117
"outputs": [],
31153118
"source": [
3116-
"from nutpie import transform_adapter\n",
31173119
"import nutpie as ntp\n",
31183120
"\n",
31193121
"# Uncomment to see arguments for with_transform_adapt\n",

notebooks/VARMAX Example.ipynb

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,21 @@
1111
"\n",
1212
"jax.config.update(\"jax_platform_name\", \"cpu\")\n",
1313
"\n",
14+
"import sys\n",
15+
"\n",
16+
"import arviz as az\n",
17+
"import matplotlib.pyplot as plt\n",
1418
"import numpy as np\n",
15-
"import statsmodels.api as sm\n",
1619
"import pandas as pd\n",
17-
"\n",
1820
"import pymc as pm\n",
1921
"import pytensor.tensor as pt\n",
20-
"import arviz as az\n",
21-
"\n",
22-
"import matplotlib.pyplot as plt\n",
23-
"import sys\n",
22+
"import statsmodels.api as sm\n",
2423
"\n",
2524
"sys.path.append(\"..\")\n",
26-
"import pymc_extras.statespace as pmss\n",
2725
"import re\n",
2826
"\n",
27+
"import pymc_extras.statespace as pmss\n",
28+
"\n",
2929
"config = {\n",
3030
" \"figure.figsize\": [12.0, 4.0],\n",
3131
" \"figure.dpi\": 72.0 * 2,\n",
@@ -841,7 +841,7 @@
841841
" new_labels = []\n",
842842
" for label in axis.yaxis.get_majorticklabels():\n",
843843
" old_text = \"[\" + label.get_text().split(\"[\")[-1]\n",
844-
" labels = eval(re.sub(\"([\\d\\w]+)\", '\"\\g<1>\"', old_text))\n",
844+
" labels = eval(re.sub(r\"([\\d\\w]+)\", r'\"\\g<1>\"', old_text))\n",
845845
" lag, other_var = labels\n",
846846
" new_text = f\"L{lag}.{other_var}\"\n",
847847
" new_labels.append(new_text)\n",

notebooks/discrete_markov_chain.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727
"outputs": [],
2828
"source": [
2929
"import arviz as az\n",
30+
"import matplotlib.pyplot as plt\n",
3031
"import numpy as np\n",
32+
"import pandas as pd\n",
3133
"import pymc as pm\n",
3234
"import pytensor\n",
3335
"import pytensor.tensor as pt\n",
34-
"import pandas as pd\n",
3536
"import statsmodels.api as sm\n",
36-
"import matplotlib.pyplot as plt\n",
3737
"\n",
3838
"from matplotlib import ticker as mtick\n",
3939
"\n",
@@ -525,7 +525,7 @@
525525
" \"dates\": dta_hamilton.index,\n",
526526
" \"obs_dates\": dta_hamilton.index[order:],\n",
527527
" \"states\": [\"State_1\", \"State_2\"],\n",
528-
" \"ar_params\": [f\"L{i+1}.phi\" for i in range(order)],\n",
528+
" \"ar_params\": [f\"L{i + 1}.phi\" for i in range(order)],\n",
529529
"}\n",
530530
"\n",
531531
"with pm.Model(coords=coords) as hmm:\n",

notebooks/latent_approx.ipynb

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,9 @@
2525
}
2626
],
2727
"source": [
28-
"import arviz as az\n",
2928
"import matplotlib.pyplot as plt\n",
3029
"import numpy as np\n",
31-
"import pymc as pm\n",
32-
"from aesara import tensor as at\n",
33-
"\n",
34-
"import pymc_extras as pmx"
30+
"import pymc as pm"
3531
]
3632
},
3733
{
@@ -42,7 +38,7 @@
4238
"outputs": [],
4339
"source": [
4440
"from pymc_extras.gp import HSGP, KarhunenLoeveExpansion, ProjectedProcess\n",
45-
"from pymc_extras.gp.latent_approx import ExpQuad, Matern12, Matern32, Matern52"
41+
"from pymc_extras.gp.latent_approx import ExpQuad"
4642
]
4743
},
4844
{

notebooks/marginalized_changepoint_model.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
"metadata": {},
77
"outputs": [],
88
"source": [
9-
"import pymc as pm\n",
10-
"from pymc_extras.model.marginal.marginal_model import MarginalModel\n",
11-
"import pandas as pd\n",
9+
"import arviz as az\n",
1210
"import numpy as np\n",
13-
"import arviz as az"
11+
"import pandas as pd\n",
12+
"import pymc as pm\n",
13+
"\n",
14+
"from pymc_extras.model.marginal.marginal_model import MarginalModel"
1415
]
1516
},
1617
{

pymc_extras/distributions/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@
2929
from pymc_extras.distributions.transforms import PartialOrder
3030

3131
__all__ = [
32+
"R2D2M2CP",
33+
"BetaNegativeBinomial",
3234
"Chi",
33-
"Maxwell",
3435
"DiscreteMarkovChain",
35-
"GeneralizedPoisson",
36-
"BetaNegativeBinomial",
3736
"GenExtreme",
38-
"R2D2M2CP",
37+
"GeneralizedPoisson",
38+
"Maxwell",
39+
"PartialOrder",
3940
"Skellam",
4041
"histogram_approximation",
41-
"PartialOrder",
4242
]

pymc_extras/distributions/histogram_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from numpy.typing import ArrayLike
2020

21-
__all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]
21+
__all__ = ["discrete_histogram", "histogram_approximation", "quantile_histogram"]
2222

2323

2424
def quantile_histogram(

0 commit comments

Comments
 (0)