Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 47 additions & 46 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"""Utility function for variable selection and bart interpretability."""

import warnings
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any, TypeVar

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -18,15 +19,15 @@

from .tree import Tree

TensorLike = Union[npt.NDArray, pt.TensorVariable]
TensorLike = TypeVar("TensorLike", npt.NDArray, pt.TensorVariable)


def _sample_posterior(
all_trees: list[list[Tree]],
X: TensorLike,
rng: np.random.Generator,
size: Optional[Union[int, tuple[int, ...]]] = None,
excluded: Optional[list[int]] = None,
size: int | tuple[int, ...] | None = None,
excluded: list[int] | None = None,
shape: int = 1,
) -> npt.NDArray:
"""
Expand All @@ -51,7 +52,7 @@ def _sample_posterior(
X = X.eval()

if size is None:
size_iter: Union[list, tuple] = (1,)
size_iter: list | tuple = (1,)
elif isinstance(size, int):
size_iter = [size]
else:
Expand All @@ -78,9 +79,9 @@ def _sample_posterior(

def plot_convergence(
idata: Any,
var_name: Optional[str] = None,
var_name: str | None = None,
kind: str = "ecdf",
figsize: Optional[tuple[float, float]] = None,
figsize: tuple[float, float] | None = None,
ax=None,
) -> None:
"""
Expand Down Expand Up @@ -114,23 +115,23 @@ def plot_convergence(
def plot_ice(
bartrv: Variable,
X: npt.NDArray,
Y: Optional[npt.NDArray] = None,
var_idx: Optional[list[int]] = None,
var_discrete: Optional[list[int]] = None,
func: Optional[Callable] = None,
centered: Optional[bool] = True,
Y: npt.NDArray | None = None,
var_idx: list[int] | None = None,
var_discrete: list[int] | None = None,
func: Callable | None = None,
centered: bool | None = True,
samples: int = 100,
instances: int = 30,
random_seed: Optional[int] = None,
random_seed: int | None = None,
sharey: bool = True,
smooth: bool = True,
grid: str = "long",
color="C0",
color_mean: str = "C0",
alpha: float = 0.1,
figsize: Optional[tuple[float, float]] = None,
smooth_kwargs: Optional[dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
figsize: tuple[float, float] | None = None,
smooth_kwargs: dict[str, Any] | None = None,
ax: plt.Axes | None = None,
) -> list[plt.Axes]:
"""
Individual conditional expectation plot.
Expand Down Expand Up @@ -258,24 +259,24 @@ def identity(x):
def plot_pdp(
bartrv: Variable,
X: npt.NDArray,
Y: Optional[npt.NDArray] = None,
Y: npt.NDArray | None = None,
xs_interval: str = "quantiles",
xs_values: Optional[Union[int, list[float]]] = None,
var_idx: Optional[list[int]] = None,
var_discrete: Optional[list[int]] = None,
func: Optional[Callable] = None,
xs_values: int | list[float] | None = None,
var_idx: list[int] | None = None,
var_discrete: list[int] | None = None,
func: Callable | None = None,
samples: int = 200,
ref_line: bool = True,
random_seed: Optional[int] = None,
random_seed: int | None = None,
sharey: bool = True,
smooth: bool = True,
grid: str = "long",
color="C0",
color_mean: str = "C0",
alpha: float = 0.1,
figsize: Optional[tuple[float, float]] = None,
smooth_kwargs: Optional[dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
figsize: tuple[float, float] | None = None,
smooth_kwargs: dict[str, Any] | None = None,
ax: plt.Axes = None,
) -> list[plt.Axes]:
"""
Partial dependence plot.
Expand Down Expand Up @@ -425,8 +426,8 @@ def _create_figure_axes(
var_idx: list[int],
grid: str = "long",
sharey: bool = True,
figsize: Optional[tuple[float, float]] = None,
ax: Optional[plt.Axes] = None,
figsize: tuple[float, float] | None = None,
ax: plt.Axes | None = None,
) -> tuple[plt.Figure, list[plt.Axes], int]:
"""
Create and return the figure and axes objects for plotting the variables.
Expand Down Expand Up @@ -506,11 +507,11 @@ def _get_axes(grid, n_plots, sharex, sharey, figsize):

def _prepare_plot_data(
X: npt.NDArray,
Y: Optional[npt.NDArray] = None,
Y: npt.NDArray | None = None,
xs_interval: str = "quantiles",
xs_values: Optional[Union[int, list[float]]] = None,
var_idx: Optional[list[int]] = None,
var_discrete: Optional[list[int]] = None,
xs_values: int | list[float] | None = None,
var_idx: list[int] | None = None,
var_discrete: list[int] | None = None,
) -> tuple[
npt.NDArray,
list[str],
Expand All @@ -519,7 +520,7 @@ def _prepare_plot_data(
list[int],
list[int],
str,
Union[int, None, list[float]],
int | None | list[float],
]:
"""
Prepare data for plotting.
Expand Down Expand Up @@ -600,7 +601,7 @@ def _prepare_plot_data(
def _create_pdp_data(
X: npt.NDArray,
xs_interval: str,
xs_values: Optional[Union[int, list[float]]] = None,
xs_values: int | list[float] | None = None,
) -> npt.NDArray:
"""
Create data for partial dependence plot.
Expand Down Expand Up @@ -636,7 +637,7 @@ def _smooth_mean(
new_x: npt.NDArray,
p_di: npt.NDArray,
kind: str = "neutral",
smooth_kwargs: Optional[dict[str, Any]] = None,
smooth_kwargs: dict[str, Any] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""
Smooth the mean data for plotting.
Expand Down Expand Up @@ -805,7 +806,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
fixed: int = 0,
samples: int = 50,
random_seed: int | None = None,
) -> dict[str, object]:
) -> dict[str, npt.NDArray]:
"""
Estimates variable importance from the BART-posterior.

Expand Down Expand Up @@ -1026,11 +1027,11 @@ def vi_to_kulprit(vi_results: dict) -> list[list[str]]:

def plot_variable_importance(
vi_results: dict,
submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None,
labels: Optional[list[str]] = None,
figsize: Optional[tuple[float, float]] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
submodels: list[int] | np.ndarray | tuple[int, ...] | None = None,
labels: list[str] | None = None,
figsize: tuple[float, float] | None = None,
plot_kwargs: dict[str, Any] | None = None,
ax: plt.Axes | None = None,
):
"""
Estimates variable importance from the BART-posterior.
Expand Down Expand Up @@ -1128,13 +1129,13 @@ def plot_variable_importance(

def plot_scatter_submodels(
vi_results: dict,
func: Optional[Callable] = None,
submodels: Optional[Union[list[int], np.ndarray]] = None,
func: Callable | None = None,
submodels: list[int] | np.ndarray | None = None,
grid: str = "long",
labels: Optional[list[str]] = None,
figsize: Optional[tuple[float, float]] = None,
plot_kwargs: Optional[dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
labels: list[str] | None = None,
figsize: tuple[float, float] | None = None,
plot_kwargs: dict[str, Any] | None = None,
ax: plt.Axes | None = None,
) -> list[plt.Axes]:
"""
Plot submodel's predictions against reference-model's predictions.
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ pyupgrade = 1

[tool.mypy]
files = "pymc_bart/*.py"
plugins = "numpy.typing.mypy_plugin"

[tool.mypy-matplotlib]
ignore_missing_imports = true
Expand Down
Loading