Skip to content

Commit a7c1090

Browse files
committed
fixing linting
Signed-off-by: Nathaniel <[email protected]>
1 parent 2578cce commit a7c1090

File tree

3 files changed

+23
-22
lines changed

3 files changed

+23
-22
lines changed

causalpy/pymc_models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def build_model( # type: ignore
677677
Z: np.ndarray,
678678
y: np.ndarray,
679679
t: np.ndarray,
680-
coords: Dict[str, Any],
680+
coords: dict[str, Any],
681681
priors,
682682
vs_prior_type=None,
683683
vs_hyperparams=None,
@@ -721,7 +721,8 @@ def build_model( # type: ignore
721721
warnings.warn(
722722
"Variable selection priors specified. "
723723
"The 'mus' and 'sigmas' in the priors dict will be ignored "
724-
"for beta coefficients. Only 'eta' and 'lkj_sd' will be used."
724+
"for beta coefficients. Only 'eta' and 'lkj_sd' will be used.",
725+
stacklevel=2,
725726
)
726727

727728
# Create coefficient priors

causalpy/variable_selection_priors.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
top of the pymc-extras Prior infrastructure.
2121
"""
2222

23-
from typing import Any, Dict, Optional, Union
23+
from typing import Any
2424

2525
import numpy as np
2626
import pandas as pd
@@ -30,7 +30,7 @@
3030

3131

3232
def _relaxed_bernoulli_transform(
33-
p: Union[float, pt.TensorVariable], temperature: float = 0.1
33+
p: float | pt.TensorVariable, temperature: float = 0.1
3434
):
3535
"""
3636
Transform function for relaxed (continuous) Bernoulli distribution.
@@ -100,7 +100,7 @@ def __init__(
100100
pi_beta: float = 2,
101101
slab_sigma: float = 2,
102102
temperature: float = 0.1,
103-
dims: Optional[Union[str, tuple]] = None,
103+
dims: str | tuple | None = None,
104104
):
105105
self.pi_alpha = pi_alpha
106106
self.pi_beta = pi_beta
@@ -176,11 +176,11 @@ class HorseshoePrior:
176176

177177
def __init__(
178178
self,
179-
tau0: Optional[float] = None,
179+
tau0: float | None = None,
180180
nu: float = 3,
181181
c2_alpha: float = 2,
182182
c2_beta: float = 2,
183-
dims: Optional[Union[str, tuple]] = None,
183+
dims: str | tuple | None = None,
184184
):
185185
self.tau0 = tau0
186186
self.nu = nu
@@ -277,7 +277,7 @@ class VariableSelectionPrior:
277277
... beta = vs_prior.create_prior(name="beta", n_params=5, dims="features")
278278
"""
279279

280-
def __init__(self, prior_type: str, hyperparams: Optional[Dict[str, Any]] = None):
280+
def __init__(self, prior_type: str, hyperparams: dict[str, Any] | None = None):
281281
"""Initialize the variable selection prior factory."""
282282
self.prior_type = prior_type.lower()
283283
self.hyperparams = hyperparams or {}
@@ -292,8 +292,8 @@ def __init__(self, prior_type: str, hyperparams: Optional[Dict[str, Any]] = None
292292
self._prior_instance = None
293293

294294
def _get_default_hyperparams(
295-
self, n_params: int, X: Optional[np.ndarray] = None
296-
) -> Dict[str, Any]:
295+
self, n_params: int, X: np.ndarray | None = None
296+
) -> dict[str, Any]:
297297
"""
298298
Get default hyperparameters for the chosen prior type.
299299
@@ -346,10 +346,10 @@ def create_prior(
346346
self,
347347
name: str,
348348
n_params: int,
349-
dims: Optional[Union[str, tuple]] = None,
350-
X: Optional[np.ndarray] = None,
351-
hyperparams: Optional[Dict[str, Any]] = None,
352-
) -> Union[pm.Deterministic, pm.Distribution]:
349+
dims: str | tuple | None = None,
350+
X: np.ndarray | None = None,
351+
hyperparams: dict[str, Any] | None = None,
352+
) -> pm.Deterministic | pm.Distribution:
353353
"""
354354
Create the specified prior on a coefficient vector.
355355
@@ -559,10 +559,10 @@ def create_variable_selection_prior(
559559
prior_type: str,
560560
name: str,
561561
n_params: int,
562-
dims: Optional[Union[str, tuple]] = None,
563-
X: Optional[np.ndarray] = None,
564-
hyperparams: Optional[Dict[str, Any]] = None,
565-
) -> Union[pm.Deterministic, pm.Distribution]:
562+
dims: str | tuple | None = None,
563+
X: np.ndarray | None = None,
564+
hyperparams: dict[str, Any] | None = None,
565+
) -> pm.Deterministic | pm.Distribution:
566566
"""
567567
Convenience function to create a variable selection prior in one call.
568568

docs/source/_static/interrogate_badge.svg

Lines changed: 4 additions & 4 deletions
Loading

0 commit comments

Comments
 (0)