
Commit 69c4307

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1ab4a49 commit 69c4307

Large commits have some content hidden by default; only a subset of the 52 changed files is shown below.

52 files changed: +284 / -244 lines
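Nearly everything below is mechanical cleanup: trailing commas added to multi-line calls in the notebooks, import blocks regrouped with a blank line between plain `import` statements and `from` imports, and `typing.Tuple`/`List`/`Dict`/`Optional`/`Union` annotations rewritten to builtin generics and `|` unions (PEP 585/604). The repository's `.pre-commit-config.yaml` is not shown on this page, so exactly which hooks ran is an assumption; pyupgrade, isort, and a black-style formatter commonly produce these rewrites. A minimal before/after sketch of the annotation change:

# Before: typing-module generics (pre-PEP 585/604 style)
from typing import Dict, Optional, Tuple

def summarize(xs: Optional[Tuple[int, ...]]) -> Dict[str, int]:
    return {"n": 0 if xs is None else len(xs)}

# After: builtin generics (3.9+) and | unions (3.10+), no typing import needed
def summarize(xs: tuple[int, ...] | None) -> dict[str, int]:
    return {"n": 0 if xs is None else len(xs)}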

notebooks/SARMA Example.ipynb

Lines changed: 2 additions & 2 deletions
@@ -1554,7 +1554,7 @@
 " hdi_forecast.coords[\"time\"].values,\n",
 " *hdi_forecast.isel(observed_state=0).values.T,\n",
 " alpha=0.25,\n",
-" color=\"tab:blue\"\n",
+" color=\"tab:blue\",\n",
 " )\n",
 "ax.set_title(\"Porcupine Graph of 10-Period Forecasts (parameters estimated on all data)\")\n",
 "plt.show()"
@@ -2692,7 +2692,7 @@
 " *forecast_hdi.values.T,\n",
 " label=\"Forecast 94% HDI\",\n",
 " color=\"tab:orange\",\n",
-" alpha=0.25\n",
+" alpha=0.25,\n",
 ")\n",
 "ax.legend()\n",
 "plt.show()"

notebooks/Structural Timeseries Modeling.ipynb

Lines changed: 3 additions & 3 deletions
@@ -1657,7 +1657,7 @@
 " nile.index,\n",
 " *component_hdi.smoothed_posterior.sel(state=state).values.T,\n",
 " color=\"tab:blue\",\n",
-" alpha=0.15\n",
+" alpha=0.15,\n",
 " )\n",
 " axis.set_title(state.title())"
 ]
@@ -1706,7 +1706,7 @@
 " *hdi.smoothed_posterior.sum(dim=\"state\").values.T,\n",
 " color=\"tab:blue\",\n",
 " alpha=0.15,\n",
-" label=\"HDI 94%\"\n",
+" label=\"HDI 94%\",\n",
 ")\n",
 "ax.legend()\n",
 "plt.show()"
@@ -2750,7 +2750,7 @@
 "ax.fill_between(\n",
 " blossom_data.index,\n",
 " *hdi_post.predicted_posterior_observed.isel(observed_state=0).values.T,\n",
-" alpha=0.25\n",
+" alpha=0.25,\n",
 ")\n",
 "blossom_data.plot(ax=ax)"
 ]
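Both notebook diffs are the same one-character fix: a trailing comma after the final argument of a multi-line call. Formatters in the black family treat that "magic trailing comma" as a signal to keep the call exploded one argument per line, and it turns later argument additions into single-line diffs. A sketch of the enforced style, with hypothetical stand-ins for the notebooks' HDI arrays:

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(50)
lower, upper = np.sin(x / 5) - 0.5, np.sin(x / 5) + 0.5

fig, ax = plt.subplots()
ax.fill_between(
    x,
    lower,
    upper,
    alpha=0.25,
    color="tab:blue",  # trailing comma: adding another kwarg later touches one line
)
plt.show()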

pymc_experimental/distributions/continuous.py

Lines changed: 4 additions & 5 deletions
@@ -19,10 +19,9 @@
 The imports from pymc are not fully replicated here: add imports as necessary.
 """
 
-from typing import Tuple, Union
-
 import numpy as np
 import pytensor.tensor as pt
+
 from pymc import ChiSquared, CustomDist
 from pymc.distributions import transforms
 from pymc.distributions.dist_math import check_parameters
@@ -39,19 +38,19 @@ class GenExtremeRV(RandomVariable):
     name: str = "Generalized Extreme Value"
     signature = "(),(),()->()"
     dtype: str = "floatX"
-    _print_name: Tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")
+    _print_name: tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")
 
     def __call__(self, mu=0.0, sigma=1.0, xi=0.0, size=None, **kwargs) -> TensorVariable:
         return super().__call__(mu, sigma, xi, size=size, **kwargs)
 
     @classmethod
     def rng_fn(
         cls,
-        rng: Union[np.random.RandomState, np.random.Generator],
+        rng: np.random.RandomState | np.random.Generator,
         mu: np.ndarray,
         sigma: np.ndarray,
         xi: np.ndarray,
-        size: Tuple[int, ...],
+        size: tuple[int, ...],
     ) -> np.ndarray:
         # Notice negative here, since remainder of GenExtreme is based on Coles parametrization
         return stats.genextreme.rvs(c=-xi, loc=mu, scale=sigma, random_state=rng, size=size)
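Beyond the annotation updates, the context here shows the one subtle point in `GenExtremeRV.rng_fn`: SciPy's `genextreme` takes a shape parameter `c` that is the negative of the shape xi in the Coles parametrization used by the rest of the distribution, hence `c=-xi`. A quick check of that sign convention with illustrative values:

import numpy as np
from scipy import stats

mu, sigma, xi = 0.0, 1.0, 0.2  # Coles: xi > 0 gives a heavy (Frechet-type) upper tail

# SciPy's shape is c = -xi, so pass the negated value
draws = stats.genextreme.rvs(c=-xi, loc=mu, scale=sigma, size=10_000,
                             random_state=np.random.default_rng(0))
print(draws.min(), draws.max())  # bounded below near mu - sigma/xi, long upper tail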

pymc_experimental/distributions/discrete.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 
 import numpy as np
 import pymc as pm
+
 from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
 from pymc.distributions.shape_utils import rv_size_is_none
 from pytensor import tensor as pt

pymc_experimental/distributions/histogram_utils.py

Lines changed: 3 additions & 4 deletions
@@ -13,18 +13,17 @@
 # limitations under the License.
 
 
-from typing import Dict
-
 import numpy as np
 import pymc as pm
+
 from numpy.typing import ArrayLike
 
 __all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]
 
 
 def quantile_histogram(
     data: ArrayLike, n_quantiles=1000, zero_inflation=False
-) -> Dict[str, ArrayLike]:
+) -> dict[str, ArrayLike]:
     try:
         import xhistogram.core
     except ImportError as e:
@@ -67,7 +66,7 @@ def quantile_histogram(
     return result
 
 
-def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]:
+def discrete_histogram(data: ArrayLike, min_count=None) -> dict[str, ArrayLike]:
     try:
         import xhistogram.core
     except ImportError as e:
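The context around these hunks shows how the module keeps `xhistogram` optional: it is imported inside each function, so `pymc_experimental` imports cleanly without it and the failure surfaces only when a histogram helper is actually called. The body of the `except` branch is cut off in this diff, so the error type and message below are assumptions; the guard pattern itself is what matters:

def histogram_helper_sketch(data):
    try:
        import xhistogram.core  # optional dependency, needed only here
    except ImportError as e:
        # Chain the original exception so the real import failure stays in the traceback
        raise RuntimeError("this function requires the optional dependency xhistogram") from e
    ...  # histogram computation using xhistogram.core goes here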

pymc_experimental/distributions/multivariate/r2d2m2cp.py

Lines changed: 17 additions & 17 deletions
@@ -14,7 +14,7 @@
 
 
 from collections import namedtuple
-from typing import Sequence, Tuple, Union
+from collections.abc import Sequence
 
 import numpy as np
 import pymc as pm
@@ -26,8 +26,8 @@
 def _psivar2musigma(
     psi: pt.TensorVariable,
     explained_var: pt.TensorVariable,
-    psi_mask: Union[pt.TensorLike, None],
-) -> Tuple[pt.TensorVariable, pt.TensorVariable]:
+    psi_mask: pt.TensorLike | None,
+) -> tuple[pt.TensorVariable, pt.TensorVariable]:
     sign = pt.sign(psi - 0.5)
     if psi_mask is not None:
         # any computation might be ignored for ~psi_mask
@@ -55,7 +55,7 @@ def _R2D2M2CP_beta(
     psi: pt.TensorVariable,
     *,
     psi_mask,
-    dims: Union[str, Sequence[str]],
+    dims: str | Sequence[str],
     centered=False,
 ) -> pt.TensorVariable:
     """R2D2M2CP beta prior.
@@ -120,7 +120,7 @@ def _R2D2M2CP_beta(
 def _broadcast_as_dims(
     *values: np.ndarray,
     dims: Sequence[str],
-) -> Union[Tuple[np.ndarray, ...], np.ndarray]:
+) -> tuple[np.ndarray, ...] | np.ndarray:
     model = pm.modelcontext(None)
     shape = [len(model.coords[d]) for d in dims]
     ret = tuple(np.broadcast_to(v, shape) for v in values)
@@ -135,7 +135,7 @@ def _psi_masked(
     positive_probs_std: pt.TensorLike,
     *,
     dims: Sequence[str],
-) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
+) -> tuple[pt.TensorLike | None, pt.TensorVariable]:
     if not (
         isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant)
     ):
@@ -172,10 +172,10 @@ def _psi_masked(
 
 def _psi(
     positive_probs: pt.TensorLike,
-    positive_probs_std: Union[pt.TensorLike, None],
+    positive_probs_std: pt.TensorLike | None,
     *,
     dims: Sequence[str],
-) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
+) -> tuple[pt.TensorLike | None, pt.TensorVariable]:
     if positive_probs_std is not None:
         mask, psi = _psi_masked(
             positive_probs=pt.as_tensor(positive_probs),
@@ -194,9 +194,9 @@ def _psi(
 
 
 def _phi(
-    variables_importance: Union[pt.TensorLike, None],
-    variance_explained: Union[pt.TensorLike, None],
-    importance_concentration: Union[pt.TensorLike, None],
+    variables_importance: pt.TensorLike | None,
+    variance_explained: pt.TensorLike | None,
+    importance_concentration: pt.TensorLike | None,
     *,
     dims: Sequence[str],
 ) -> pt.TensorVariable:
@@ -233,12 +233,12 @@ def R2D2M2CP(
     *,
     dims: Sequence[str],
     r2: pt.TensorLike,
-    variables_importance: Union[pt.TensorLike, None] = None,
-    variance_explained: Union[pt.TensorLike, None] = None,
-    importance_concentration: Union[pt.TensorLike, None] = None,
-    r2_std: Union[pt.TensorLike, None] = None,
-    positive_probs: Union[pt.TensorLike, None] = 0.5,
-    positive_probs_std: Union[pt.TensorLike, None] = None,
+    variables_importance: pt.TensorLike | None = None,
+    variance_explained: pt.TensorLike | None = None,
+    importance_concentration: pt.TensorLike | None = None,
+    r2_std: pt.TensorLike | None = None,
+    positive_probs: pt.TensorLike | None = 0.5,
+    positive_probs_std: pt.TensorLike | None = None,
     centered: bool = False,
 ) -> R2D2M2CPOut:
     """R2D2M2CP Prior.

pymc_experimental/distributions/timeseries.py

Lines changed: 2 additions & 2 deletions
@@ -1,10 +1,10 @@
 import warnings
-from typing import List, Union
 
 import numpy as np
 import pymc as pm
 import pytensor
 import pytensor.tensor as pt
+
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import (
     Distribution,
@@ -26,7 +26,7 @@
 from pytensor.tensor.random.op import RandomVariable
 
 
-def _make_outputs_info(n_lags: int, init_dist: Distribution) -> List[Union[Distribution, dict]]:
+def _make_outputs_info(n_lags: int, init_dist: Distribution) -> list[Distribution | dict]:
     """
     Two cases are needed for outputs_info in the scans used by DiscreteMarkovRv. If n_lags = 1, we need to throw away
     the first dimension of init_dist_ or else markov_chain will have shape (steps, 1, *batch_size) instead of
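The new `list[Distribution | dict]` return type mirrors `pytensor.scan`'s `outputs_info` convention: a bare initial state implies a single tap of -1, while a dict with explicit `taps` feeds the scan body the last `n_lags` states. The function body is not part of this diff, so the following is only a sketch of the shape such a helper takes, not the module's actual implementation:

def make_outputs_info_sketch(n_lags, init_dist):
    if n_lags == 1:
        # A bare entry is shorthand for {"initial": init_dist, "taps": [-1]}
        return [init_dist]
    # With more lags, scan must be told explicitly which past steps to feed back
    return [{"initial": init_dist, "taps": list(range(-n_lags, 0))}]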

pymc_experimental/gp/latent_approx.py

Lines changed: 7 additions & 5 deletions
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
-from typing import Optional
 
 import numpy as np
 import pymc as pm
 import pytensor.tensor as pt
+
 from pymc.gp.util import JITTER_DEFAULT, stabilize
 from pytensor.tensor.linalg import cholesky, solve_triangular
 
@@ -33,7 +33,7 @@ class ProjectedProcess(pm.gp.Latent):
     ## AKA: DTC
     def __init__(
         self,
-        n_inducing: Optional[int] = None,
+        n_inducing: int | None = None,
         *,
         mean_func=pm.gp.mean.Zero(),
         cov_func=pm.gp.cov.Constant(0.0),
@@ -59,20 +59,22 @@ def prior(
         self,
         name: str,
         X: np.ndarray,
-        X_inducing: Optional[np.ndarray] = None,
+        X_inducing: np.ndarray | None = None,
         jitter: float = JITTER_DEFAULT,
         **kwargs,
     ) -> np.ndarray:
         """
         Builds the GP prior with optional inducing points locations.
 
-        Parameters:
+        Parameters
+        ----------
         - name: Name for the GP variable.
         - X: Input data.
         - X_inducing: Optional. Inducing points for the GP.
         - jitter: Jitter to ensure numerical stability.
 
-        Returns:
+        Returns
+        -------
         - GP function
         """
         # Check if X is a numpy array
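The docstring hunk converts the `Parameters:` and `Returns:` labels into underlined section headers, the form that numpydoc/napoleon-based Sphinx builds parse. The dashed item lines are kept as-is, though, so this remains a partial conversion; a fully numpydoc-style docstring names each parameter with its type and indents the description, as in this hypothetical sketch:

def prior_sketch(name, X, X_inducing=None, jitter=1e-6):
    """
    Build the GP prior with optional inducing point locations.

    Parameters
    ----------
    name : str
        Name for the GP variable.
    X : np.ndarray
        Input data.
    X_inducing : np.ndarray, optional
        Inducing point locations for the GP.
    jitter : float
        Jitter added to the covariance diagonal for numerical stability.

    Returns
    -------
    np.ndarray
        The GP prior evaluated at X.
    """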

pymc_experimental/inference/fit.py

Lines changed: 0 additions & 1 deletion
@@ -40,7 +40,6 @@ def fit(method, **kwargs):
         return fit_pathfinder(**kwargs)
 
     if method == "laplace":
-
         from pymc_experimental.inference.laplace import laplace
 
         return laplace(**kwargs)

pymc_experimental/inference/laplace.py

Lines changed: 4 additions & 4 deletions
@@ -13,13 +13,14 @@
 # limitations under the License.
 
 import warnings
+
 from collections.abc import Sequence
-from typing import Optional
 
 import arviz as az
 import numpy as np
 import pymc as pm
 import xarray as xr
+
 from arviz import dict_to_dataset
 from pymc.backends.arviz import (
     coords_and_dims_for_inferencedata,
@@ -33,9 +34,9 @@
 
 def laplace(
     vars: Sequence[Variable],
-    draws: Optional[int] = 1000,
+    draws: int | None = 1000,
     model=None,
-    random_seed: Optional[RandomSeed] = None,
+    random_seed: RandomSeed | None = None,
     progressbar=True,
 ):
     """
@@ -72,7 +73,6 @@ def laplace(
 
     Examples
     --------
-
     >>> import numpy as np
     >>> import pymc as pm
     >>> import arviz as az
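The docstring example is truncated here after its imports; the following usage sketch is based only on the signature shown above, with an illustrative model and data rather than the docstring's actual example:

import numpy as np
import pymc as pm
from pymc_experimental.inference.laplace import laplace

y = np.random.default_rng(0).normal(loc=1.0, scale=0.5, size=100)
with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=y)
    # vars picks the variables to approximate; draws presumably sets how many
    # samples are returned from the fitted Gaussian (Laplace) approximation
    idata = laplace([mu, sigma], draws=1000, random_seed=1)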
