diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index cf804c5..dfb5eac 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -2,7 +2,8 @@ """Utility function for variable selection and bart interpretability.""" import warnings -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, TypeVar import matplotlib.pyplot as plt import numpy as np @@ -18,15 +19,15 @@ from .tree import Tree -TensorLike = Union[npt.NDArray, pt.TensorVariable] +TensorLike = TypeVar("TensorLike", npt.NDArray, pt.TensorVariable) def _sample_posterior( all_trees: list[list[Tree]], X: TensorLike, rng: np.random.Generator, - size: Optional[Union[int, tuple[int, ...]]] = None, - excluded: Optional[list[int]] = None, + size: int | tuple[int, ...] | None = None, + excluded: list[int] | None = None, shape: int = 1, ) -> npt.NDArray: """ @@ -51,7 +52,7 @@ def _sample_posterior( X = X.eval() if size is None: - size_iter: Union[list, tuple] = (1,) + size_iter: list | tuple = (1,) elif isinstance(size, int): size_iter = [size] else: @@ -78,9 +79,9 @@ def _sample_posterior( def plot_convergence( idata: Any, - var_name: Optional[str] = None, + var_name: str | None = None, kind: str = "ecdf", - figsize: Optional[tuple[float, float]] = None, + figsize: tuple[float, float] | None = None, ax=None, ) -> None: """ @@ -114,23 +115,23 @@ def plot_convergence( def plot_ice( bartrv: Variable, X: npt.NDArray, - Y: Optional[npt.NDArray] = None, - var_idx: Optional[list[int]] = None, - var_discrete: Optional[list[int]] = None, - func: Optional[Callable] = None, - centered: Optional[bool] = True, + Y: npt.NDArray | None = None, + var_idx: list[int] | None = None, + var_discrete: list[int] | None = None, + func: Callable | None = None, + centered: bool | None = True, samples: int = 100, instances: int = 30, - random_seed: Optional[int] = None, + random_seed: int | None = None, sharey: bool = True, smooth: bool = True, grid: str = "long", color="C0", color_mean: str = "C0", alpha: float = 0.1, - figsize: Optional[tuple[float, float]] = None, - smooth_kwargs: Optional[dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, + figsize: tuple[float, float] | None = None, + smooth_kwargs: dict[str, Any] | None = None, + ax: plt.Axes | None = None, ) -> list[plt.Axes]: """ Individual conditional expectation plot. @@ -258,24 +259,24 @@ def identity(x): def plot_pdp( bartrv: Variable, X: npt.NDArray, - Y: Optional[npt.NDArray] = None, + Y: npt.NDArray | None = None, xs_interval: str = "quantiles", - xs_values: Optional[Union[int, list[float]]] = None, - var_idx: Optional[list[int]] = None, - var_discrete: Optional[list[int]] = None, - func: Optional[Callable] = None, + xs_values: int | list[float] | None = None, + var_idx: list[int] | None = None, + var_discrete: list[int] | None = None, + func: Callable | None = None, samples: int = 200, ref_line: bool = True, - random_seed: Optional[int] = None, + random_seed: int | None = None, sharey: bool = True, smooth: bool = True, grid: str = "long", color="C0", color_mean: str = "C0", alpha: float = 0.1, - figsize: Optional[tuple[float, float]] = None, - smooth_kwargs: Optional[dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, + figsize: tuple[float, float] | None = None, + smooth_kwargs: dict[str, Any] | None = None, + ax: plt.Axes = None, ) -> list[plt.Axes]: """ Partial dependence plot. @@ -425,8 +426,8 @@ def _create_figure_axes( var_idx: list[int], grid: str = "long", sharey: bool = True, - figsize: Optional[tuple[float, float]] = None, - ax: Optional[plt.Axes] = None, + figsize: tuple[float, float] | None = None, + ax: plt.Axes | None = None, ) -> tuple[plt.Figure, list[plt.Axes], int]: """ Create and return the figure and axes objects for plotting the variables. @@ -506,11 +507,11 @@ def _get_axes(grid, n_plots, sharex, sharey, figsize): def _prepare_plot_data( X: npt.NDArray, - Y: Optional[npt.NDArray] = None, + Y: npt.NDArray | None = None, xs_interval: str = "quantiles", - xs_values: Optional[Union[int, list[float]]] = None, - var_idx: Optional[list[int]] = None, - var_discrete: Optional[list[int]] = None, + xs_values: int | list[float] | None = None, + var_idx: list[int] | None = None, + var_discrete: list[int] | None = None, ) -> tuple[ npt.NDArray, list[str], @@ -519,7 +520,7 @@ def _prepare_plot_data( list[int], list[int], str, - Union[int, None, list[float]], + int | None | list[float], ]: """ Prepare data for plotting. @@ -600,7 +601,7 @@ def _prepare_plot_data( def _create_pdp_data( X: npt.NDArray, xs_interval: str, - xs_values: Optional[Union[int, list[float]]] = None, + xs_values: int | list[float] | None = None, ) -> npt.NDArray: """ Create data for partial dependence plot. @@ -636,7 +637,7 @@ def _smooth_mean( new_x: npt.NDArray, p_di: npt.NDArray, kind: str = "neutral", - smooth_kwargs: Optional[dict[str, Any]] = None, + smooth_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: """ Smooth the mean data for plotting. @@ -805,7 +806,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 fixed: int = 0, samples: int = 50, random_seed: int | None = None, -) -> dict[str, object]: +) -> dict[str, npt.NDArray]: """ Estimates variable importance from the BART-posterior. @@ -1026,11 +1027,11 @@ def vi_to_kulprit(vi_results: dict) -> list[list[str]]: def plot_variable_importance( vi_results: dict, - submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None, - labels: Optional[list[str]] = None, - figsize: Optional[tuple[float, float]] = None, - plot_kwargs: Optional[dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, + submodels: list[int] | np.ndarray | tuple[int, ...] | None = None, + labels: list[str] | None = None, + figsize: tuple[float, float] | None = None, + plot_kwargs: dict[str, Any] | None = None, + ax: plt.Axes | None = None, ): """ Estimates variable importance from the BART-posterior. @@ -1128,13 +1129,13 @@ def plot_variable_importance( def plot_scatter_submodels( vi_results: dict, - func: Optional[Callable] = None, - submodels: Optional[Union[list[int], np.ndarray]] = None, + func: Callable | None = None, + submodels: list[int] | np.ndarray | None = None, grid: str = "long", - labels: Optional[list[str]] = None, - figsize: Optional[tuple[float, float]] = None, - plot_kwargs: Optional[dict[str, Any]] = None, - ax: Optional[plt.Axes] = None, + labels: list[str] | None = None, + figsize: tuple[float, float] | None = None, + plot_kwargs: dict[str, Any] | None = None, + ax: plt.Axes | None = None, ) -> list[plt.Axes]: """ Plot submodel's predictions against reference-model's predictions. diff --git a/pyproject.toml b/pyproject.toml index 2773123..2afa2c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ pyupgrade = 1 [tool.mypy] files = "pymc_bart/*.py" -plugins = "numpy.typing.mypy_plugin" [tool.mypy-matplotlib] ignore_missing_imports = true