diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index fe00024..4f55bc1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,14 +12,14 @@ ci:
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.3
+    rev: v0.8.4
     hooks:
       - id: ruff
         args: ["--fix", "--output-format=full"]
       - id: ruff-format
         args: ["--line-length=100"]
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.13.0
+    rev: v1.14.0
     hooks:
       - id: mypy
         args: [--ignore-missing-imports]
diff --git a/mypy.ini b/mypy.ini
deleted file mode 100644
index 56088d7..0000000
--- a/mypy.ini
+++ /dev/null
@@ -1,15 +0,0 @@
-[mypy]
-files = pymc_bart/*.py
-plugins = numpy.typing.mypy_plugin
-
-[mypy-matplotlib.*]
-ignore_missing_imports = True
-
-[mypy-numba.*]
-ignore_missing_imports = True
-
-[mypy-pymc.*]
-ignore_missing_imports = True
-
-[mypy-scipy.*]
-ignore_missing_imports = True
diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py
index decb499..ac2be35 100644
--- a/pymc_bart/bart.py
+++ b/pymc_bart/bart.py
@@ -132,7 +132,7 @@ def __new__(
         alpha: float = 0.95,
         beta: float = 2.0,
         response: str = "constant",
-        split_prior: Optional[npt.NDArray[np.float64]] = None,
+        split_prior: Optional[npt.NDArray] = None,
         split_rules: Optional[list[SplitRule]] = None,
         separate_trees: Optional[bool] = False,
         **kwargs,
@@ -203,9 +203,7 @@ def get_moment(cls, rv, size, *rv_inputs):
         return mean
 
 
-def preprocess_xy(
-    X: TensorLike, Y: TensorLike
-) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
+def preprocess_xy(X: TensorLike, Y: TensorLike) -> tuple[npt.NDArray, npt.NDArray]:
     if isinstance(Y, (Series, DataFrame)):
         Y = Y.to_numpy()
     if isinstance(X, (Series, DataFrame)):
diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py
index 1505f15..014313a 100644
--- a/pymc_bart/pgbart.py
+++ b/pymc_bart/pgbart.py
@@ -16,6 +16,8 @@
 
 import numpy as np
 import numpy.typing as npt
+import pymc as pm
+import pytensor.tensor as pt
 from numba import njit
 from pymc.initial_point import PointType
 from pymc.model import Model, modelcontext
@@ -120,15 +122,15 @@ class PGBART(ArrayStepShared):
         "tune": (bool, []),
     }
 
-    def __init__(  # noqa: PLR0915
+    def __init__(  # noqa: PLR0912, PLR0915
         self,
-        vars=None,  # pylint: disable=redefined-builtin
+        vars: list[pm.Distribution] | None = None,
         num_particles: int = 10,
         batch: tuple[float, float] = (0.1, 0.1),
         model: Optional[Model] = None,
         initial_point: PointType | None = None,
-        compile_kwargs: dict | None = None,  # pylint: disable=unused-argument
-    ):
+        compile_kwargs: dict | None = None,
+    ) -> None:
         model = modelcontext(model)
         if initial_point is None:
             initial_point = model.initial_point()
@@ -137,6 +139,10 @@ def __init__(  # noqa: PLR0915
         else:
             vars = [model.rvs_to_values.get(var, var) for var in vars]
             vars = inputvars(vars)
+
+        if not vars:
+            raise ValueError("Unable to find variables to sample")
+
         value_bart = vars[0]
         self.bart = model.values_to_rvs[value_bart].owner.op
 
@@ -325,7 +331,7 @@ def normalize(self, particles: list[ParticleTree]) -> float:
         return wei / wei.sum()
 
     def resample(
-        self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
+        self, particles: list[ParticleTree], normalized_weights: npt.NDArray
     ) -> list[ParticleTree]:
         """
         Use systematic resample for all but the first particle
@@ -347,7 +353,7 @@ def resample(
         return particles
 
     def get_particle_tree(
-        self, particles: list[ParticleTree], normalized_weights: npt.NDArray[np.float64]
+        self, particles: list[ParticleTree], normalized_weights: npt.NDArray
     ) -> tuple[ParticleTree, Tree]:
         """
         Sample a new particle and associated tree
@@ -359,7 +365,7 @@ def get_particle_tree(
 
         return new_particle, new_particle.tree
 
-    def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
+    def systematic(self, normalized_weights: npt.NDArray) -> npt.NDArray[np.int_]:
         """
         Systematic resampling.
 
@@ -395,7 +401,7 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None:
         particle.log_weight = new_likelihood
 
     @staticmethod
-    def competence(var, has_grad):
+    def competence(var: pm.Distribution, has_grad: bool) -> Competence:
        """PGBART is only suitable for BART distributions."""
         dist = getattr(var.owner, "op", None)
         if isinstance(dist, BARTRV):
@@ -406,12 +412,12 @@ def competence(var, has_grad):
 class RunningSd:
     """Welford's online algorithm for computing the variance/standard deviation"""
 
-    def __init__(self, shape: tuple) -> None:
+    def __init__(self, shape: tuple[int, ...]) -> None:
         self.count = 0  # number of data points
         self.mean = np.zeros(shape)  # running mean
         self.m_2 = np.zeros(shape)  # running second moment
 
-    def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
+    def update(self, new_value: npt.NDArray) -> Union[float, npt.NDArray]:
         self.count = self.count + 1
         self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
         return fast_mean(std)
@@ -420,10 +426,10 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray
 @njit
 def _update(
     count: int,
-    mean: npt.NDArray[np.float64],
-    m_2: npt.NDArray[np.float64],
-    new_value: npt.NDArray[np.float64],
-) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
+    mean: npt.NDArray,
+    m_2: npt.NDArray,
+    new_value: npt.NDArray,
+) -> tuple[npt.NDArray, npt.NDArray, Union[float, npt.NDArray]]:
     delta = new_value - mean
     mean += delta / count
     delta2 = new_value - mean
@@ -434,7 +440,7 @@ def _update(
 
 
 class SampleSplittingVariable:
-    def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
+    def __init__(self, alpha_vec: npt.NDArray) -> None:
         """
         Sample splitting variables proportional to `alpha_vec`.
 
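Side note on the `_update` kernel above: it is one step of Welford's online algorithm, so the running standard deviation should agree with the batch result. A minimal standalone sketch of that check (NumPy only, `@njit` omitted, the name `welford_update` invented here for illustration):

```python
import numpy as np

def welford_update(count, mean, m_2, new_value):
    # Mirrors _update: shift the running mean by delta/count, then
    # accumulate the second moment from the pre- and post-update deltas.
    delta = new_value - mean
    mean = mean + delta / count
    delta2 = new_value - mean
    m_2 = m_2 + delta * delta2
    return mean, m_2, np.sqrt(m_2 / count)

rng = np.random.default_rng(0)
data = rng.normal(size=(500, 3))

mean, m_2 = np.zeros(3), np.zeros(3)
for count, row in enumerate(data, start=1):
    mean, m_2, std = welford_update(count, mean, m_2, row)

# Population std (ddof=0), which is what m_2 / count corresponds to.
assert np.allclose(std, data.std(axis=0))
```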
@@ -547,16 +553,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d
 
 
 def draw_leaf_value(
-    y_mu_pred: npt.NDArray[np.float64],
-    x_mu: npt.NDArray[np.float64],
+    y_mu_pred: npt.NDArray,
+    x_mu: npt.NDArray,
     m: int,
-    norm: npt.NDArray[np.float64],
+    norm: npt.NDArray,
     shape: int,
     response: str,
-) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
+) -> tuple[npt.NDArray, Optional[npt.NDArray]]:
     """Draw Gaussian distributed leaf values."""
     linear_params = None
-    mu_mean = np.empty(shape)
+    mu_mean: npt.NDArray
     if y_mu_pred.size == 0:
         return np.zeros(shape), linear_params
 
@@ -571,7 +577,7 @@ def draw_leaf_value(
 
 
 @njit
-def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
+def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
     """Use Numba to speed up the computation of the mean."""
     if ari.ndim == 1:
         count = ari.shape[0]
@@ -590,11 +596,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float
 
 
 @njit
 def fast_linear_fit(
-    x: npt.NDArray[np.float64],
-    y: npt.NDArray[np.float64],
+    x: npt.NDArray,
+    y: npt.NDArray,
     m: int,
-    norm: npt.NDArray[np.float64],
-) -> tuple[npt.NDArray[np.float64], list[npt.NDArray[np.float64]]]:
+    norm: npt.NDArray,
+) -> tuple[npt.NDArray, list[npt.NDArray]]:
     n = len(x)
     y = y / m + np.expand_dims(norm, axis=1)
@@ -678,17 +684,17 @@ def update(self):
 
 
 @njit
 def inverse_cdf(
-    single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
+    single_uniform: npt.NDArray, normalized_weights: npt.NDArray
 ) -> npt.NDArray[np.int_]:
     """
     Inverse CDF algorithm for a finite distribution.
 
     Parameters
     ----------
-    single_uniform: npt.NDArray[np.float64]
+    single_uniform: npt.NDArray
         Ordered points in [0,1]
-    normalized_weights: npt.NDArray[np.float64])
+    normalized_weights: npt.NDArray
         Normalized weights
 
     Returns
@@ -711,7 +717,7 @@ def inverse_cdf(
 
 
 @njit
-def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
+def jitter_duplicated(array: npt.NDArray, std: float) -> npt.NDArray:
     """
     Jitter duplicated values.
     """
@@ -727,12 +733,17 @@ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray
 
 
 @njit
-def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
+def are_whole_number(array: npt.NDArray) -> np.bool_:
     """Check if all values in array are whole numbers"""
     return np.all(np.mod(array[~np.isnan(array)], 1) == 0)
 
 
-def logp(point, out_vars, vars, shared):  # pylint: disable=redefined-builtin
+def logp(
+    point,
+    out_vars: list[pm.Distribution],
+    vars: list[pm.Distribution],
+    shared: list[pt.TensorVariable],
+):
     """Compile PyTensor function of the model and the input and output variables.
 
     Parameters
diff --git a/pymc_bart/tree.py b/pymc_bart/tree.py
index 7655175..61e5050 100644
--- a/pymc_bart/tree.py
+++ b/pymc_bart/tree.py
@@ -28,7 +28,7 @@ class Node:
 
     Attributes
     ----------
-    value : npt.NDArray[np.float64]
+    value : npt.NDArray
     idx_data_points : Optional[npt.NDArray[np.int_]]
     idx_split_variable : int
     linear_params: Optional[list[float]] = None
@@ -38,11 +38,11 @@ class Node:
 
     def __init__(
         self,
-        value: npt.NDArray[np.float64] = np.array([-1.0]),
+        value: npt.NDArray = np.array([-1.0]),
         nvalue: int = 0,
         idx_data_points: Optional[npt.NDArray[np.int_]] = None,
         idx_split_variable: int = -1,
-        linear_params: Optional[list[npt.NDArray[np.float64]]] = None,
+        linear_params: Optional[list[npt.NDArray]] = None,
     ) -> None:
         self.value = value
         self.nvalue = nvalue
@@ -53,11 +53,11 @@ def __init__(
     @classmethod
     def new_leaf_node(
         cls,
-        value: npt.NDArray[np.float64],
+        value: npt.NDArray,
         nvalue: int = 0,
         idx_data_points: Optional[npt.NDArray[np.int_]] = None,
         idx_split_variable: int = -1,
-        linear_params: Optional[list[npt.NDArray[np.float64]]] = None,
+        linear_params: Optional[list[npt.NDArray]] = None,
     ) -> "Node":
         return cls(
             value=value,
@@ -101,7 +101,7 @@ class Tree:
         The dictionary's keys are integers that represent the nodes position.
         The dictionary's values are objects of type Node that represent the split and leaf nodes
         of the tree itself.
-    output: Optional[npt.NDArray[np.float64]]
+    output: Optional[npt.NDArray]
         Array of shape number of observations, shape
     split_rules : list[SplitRule]
         List of SplitRule objects, one per column in input data.
@@ -122,7 +122,7 @@ class Tree:
     def __init__(
         self,
         tree_structure: dict[int, Node],
-        output: npt.NDArray[np.float64],
+        output: npt.NDArray,
         split_rules: list[SplitRule],
         idx_leaf_nodes: Optional[list[int]] = None,
     ) -> None:
@@ -134,7 +134,7 @@ def __init__(
     @classmethod
     def new_tree(
         cls,
-        leaf_node_value: npt.NDArray[np.float64],
+        leaf_node_value: npt.NDArray,
         idx_data_points: Optional[npt.NDArray[np.int_]],
         num_observations: int,
         shape: int,
@@ -190,7 +190,7 @@ def grow_leaf_node(
         self,
         current_node: Node,
         selected_predictor: int,
-        split_value: npt.NDArray[np.float64],
+        split_value: npt.NDArray,
         index_leaf_node: int,
     ) -> None:
         current_node.value = split_value
@@ -222,7 +222,7 @@ def get_split_variables(self) -> Generator[int, None, None]:
             if node.is_split_node():
                 yield node.idx_split_variable
 
-    def _predict(self) -> npt.NDArray[np.float64]:
+    def _predict(self) -> npt.NDArray:
         output = self.output
 
         if self.idx_leaf_nodes is not None:
@@ -233,23 +233,23 @@ def _predict(self) -> npt.NDArray[np.float64]:
 
     def predict(
         self,
-        x: npt.NDArray[np.float64],
+        x: npt.NDArray,
         excluded: Optional[list[int]] = None,
         shape: int = 1,
-    ) -> npt.NDArray[np.float64]:
+    ) -> npt.NDArray:
         """
         Predict output of tree for an (un)observed point x.
 
         Parameters
         ----------
-        x : npt.NDArray[np.float64]
+        x : npt.NDArray
             Unobserved point
         excluded: Optional[list[int]]
             Indexes of the variables to exclude when computing predictions
 
         Returns
         -------
-        npt.NDArray[np.float64]
+        npt.NDArray
             Value of the leaf value where the unobserved point lies.
         """
         if excluded is None:
@@ -259,16 +259,16 @@ def predict(
 
     def _traverse_tree(
         self,
-        X: npt.NDArray[np.float64],
+        X: npt.NDArray,
         excluded: Optional[list[int]] = None,
         shape: Union[int, tuple[int, ...]] = 1,
-    ) -> npt.NDArray[np.float64]:
+    ) -> npt.NDArray:
         """
         Traverse the tree starting from the root node given an (un)observed point.
 
         Parameters
         ----------
-        X : npt.NDArray[np.float64]
+        X : npt.NDArray
             (Un)observed point(s)
         node_index : int
             Index of the node to start the traversal from
@@ -279,14 +279,16 @@ def _traverse_tree(
 
         Returns
         -------
-        npt.NDArray[np.float64]
+        npt.NDArray
             Leaf node value or mean of leaf node values
         """
         x_shape = (1,) if len(X.shape) == 1 else X.shape[:-1]
         nd_dims = (...,) + (None,) * len(x_shape)
 
-        stack = [(0, np.ones(x_shape), 0)]  # (node_index, weight, idx_split_variable) initial state
+        stack: list[tuple[int, npt.NDArray, int]] = [
+            (0, np.ones(x_shape), 0)
+        ]  # (node_index, weight, idx_split_variable) initial state
         p_d = (
             np.zeros(shape + x_shape) if isinstance(shape, tuple) else np.zeros((shape,) + x_shape)
         )
 
@@ -309,9 +311,19 @@ def _traverse_tree(
             )
             if excluded is not None and idx_split_variable in excluded:
                 prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue
-                stack.append((left_node_index, weights * prop_nvalue_left, idx_split_variable))
                 stack.append(
-                    (right_node_index, weights * (1 - prop_nvalue_left), idx_split_variable)
+                    (
+                        left_node_index,
+                        weights * prop_nvalue_left,
+                        idx_split_variable,
+                    )
+                )
+                stack.append(
+                    (
+                        right_node_index,
+                        weights * (1 - prop_nvalue_left),
+                        idx_split_variable,
+                    )
                 )
             else:
                 to_left = (
@@ -328,14 +340,14 @@ def _traverse_tree(
         return p_d
 
     def _traverse_leaf_values(
-        self, leaf_values: list[npt.NDArray[np.float64]], leaf_n_values: list[int], node_index: int
+        self, leaf_values: list[npt.NDArray], leaf_n_values: list[int], node_index: int
     ) -> None:
         """
         Traverse the tree appending leaf values starting from a particular node.
 
         Parameters
         ----------
-        leaf_values : list[npt.NDArray[np.float64]]
+        leaf_values : list[npt.NDArray]
         node_index : int
         """
         node = self.get_node(node_index)
diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py
index d9d5241..58d14b8 100644
--- a/pymc_bart/utils.py
+++ b/pymc_bart/utils.py
@@ -17,7 +17,7 @@
 
 from .tree import Tree
 
-TensorLike = Union[npt.NDArray[np.float64], pt.TensorVariable]
+TensorLike = Union[npt.NDArray, pt.TensorVariable]
 
 
 def _sample_posterior(
@@ -27,7 +27,7 @@ def _sample_posterior(
     size: Optional[Union[int, tuple[int, ...]]] = None,
     excluded: Optional[list[int]] = None,
     shape: int = 1,
-) -> npt.NDArray[np.float64]:
+) -> npt.NDArray:
     """
     Generate samples from the BART-posterior.
 
@@ -139,8 +139,8 @@ def plot_convergence(
 
 def plot_ice(
     bartrv: Variable,
-    X: npt.NDArray[np.float64],
-    Y: Optional[npt.NDArray[np.float64]] = None,
+    X: npt.NDArray,
+    Y: Optional[npt.NDArray] = None,
     var_idx: Optional[list[int]] = None,
     var_discrete: Optional[list[int]] = None,
     func: Optional[Callable] = None,
@@ -165,9 +165,9 @@ def plot_ice(
     ----------
     bartrv : BART Random Variable
         BART variable once the model that include it has been fitted.
-    X : npt.NDArray[np.float64]
+    X : npt.NDArray
         The covariate matrix.
-    Y : Optional[npt.NDArray[np.float64]], by default None.
+    Y : Optional[npt.NDArray], by default None.
         The response vector.
     var_idx : Optional[list[int]], by default None.
         List of the indices of the covariate for which to compute the pdp or ice.
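A note on the recurring type change in this file: bare `npt.NDArray` is equivalent to `npt.NDArray[Any]`, i.e. `np.ndarray[Any, np.dtype[Any]]`, so dropping the `[np.float64]` parameter keeps the annotations valid while no longer constraining the dtype. A small illustration (not from this patch) of what the looser annotation accepts:

```python
import numpy as np
import numpy.typing as npt

def scale(x: npt.NDArray, factor: float) -> npt.NDArray:
    # Bare npt.NDArray behaves like npt.NDArray[Any]: any dtype type-checks.
    return x * factor

scale(np.arange(3), 2.0)                   # int64 array: accepted
scale(np.ones(3, dtype=np.float32), 2.0)   # float32 array: accepted

strict: npt.NDArray[np.float64] = np.zeros(2)  # dtype-specific form still available
```

The trade-off is that mypy will no longer flag dtype mismatches at these boundaries.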
@@ -283,8 +283,8 @@ def identity(x):
 
 def plot_pdp(
     bartrv: Variable,
-    X: npt.NDArray[np.float64],
-    Y: Optional[npt.NDArray[np.float64]] = None,
+    X: npt.NDArray,
+    Y: Optional[npt.NDArray] = None,
     xs_interval: str = "quantiles",
     xs_values: Optional[Union[int, list[float]]] = None,
     var_idx: Optional[list[int]] = None,
@@ -310,9 +310,9 @@ def plot_pdp(
     ----------
     bartrv : BART Random Variable
         BART variable once the model that include it has been fitted.
-    X : npt.NDArray[np.float64]
+    X : npt.NDArray
         The covariate matrix.
-    Y : Optional[npt.NDArray[np.float64]], by default None.
+    Y : Optional[npt.NDArray], by default None.
         The response vector.
     xs_interval : str
         Method used to compute the values X used to evaluate the predicted function. "linear",
@@ -526,14 +526,14 @@ def _get_axes(grid, n_plots, sharex, sharey, figsize):
 
 
 def _prepare_plot_data(
-    X: npt.NDArray[np.float64],
-    Y: Optional[npt.NDArray[np.float64]] = None,
+    X: npt.NDArray,
+    Y: Optional[npt.NDArray] = None,
     xs_interval: str = "quantiles",
     xs_values: Optional[Union[int, list[float]]] = None,
     var_idx: Optional[list[int]] = None,
     var_discrete: Optional[list[int]] = None,
 ) -> tuple[
-    npt.NDArray[np.float64],
+    npt.NDArray,
     list[str],
     str,
     list[int],
@@ -619,10 +619,10 @@ def _prepare_plot_data(
 
 
 def _create_pdp_data(
-    X: npt.NDArray[np.float64],
+    X: npt.NDArray,
     xs_interval: str,
     xs_values: Optional[Union[int, list[float]]] = None,
-) -> npt.NDArray[np.float64]:
+) -> npt.NDArray:
     """
     Create data for partial dependence plot.
 
@@ -637,7 +637,7 @@ def _create_pdp_data(
 
     Returns
     -------
-    npt.NDArray[np.float64]
+    npt.NDArray
         A 2D array for the fake_X data.
     """
     if xs_interval == "insample":
@@ -654,8 +654,8 @@ def _create_pdp_data(
 
 
 def _smooth_mean(
-    new_x: npt.NDArray[np.float64],
-    p_di: npt.NDArray[np.float64],
+    new_x: npt.NDArray,
+    p_di: npt.NDArray,
     kind: str = "pdp",
     smooth_kwargs: Optional[dict[str, Any]] = None,
 ) -> tuple[np.ndarray, np.ndarray]:
@@ -701,7 +701,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
     ----------
     idata : InferenceData
         InferenceData containing a collection of BART_trees in sample_stats group
-    X : npt.NDArray[np.float64]
+    X : npt.NDArray
         The covariate matrix.
     labels : Optional[list[str]]
         List of the names of the covariates. If X is a DataFrame the names of the covariables will
@@ -767,7 +767,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
 def compute_variable_importance(  # noqa: PLR0915 PLR0912
     idata: az.InferenceData,
     bartrv: Variable,
-    X: npt.NDArray[np.float64],
+    X: npt.NDArray,
     method: str = "VI",
     fixed: int = 0,
     samples: int = 50,
@@ -782,7 +782,7 @@ def compute_variable_importance(  # noqa: PLR0915 PLR0912
         InferenceData containing a collection of BART_trees in sample_stats group
     bartrv : BART Random Variable
         BART variable once the model that include it has been fitted.
-    X : npt.NDArray[np.float64]
+    X : npt.NDArray
         The covariate matrix.
     method : str
         Method used to rank variables. Available options are "VI" (default), "backward"
@@ -826,9 +826,9 @@ def compute_variable_importance(  # noqa: PLR0915 PLR0912
     else:
         labels = np.arange(n_vars).astype(str)
 
-    r2_mean = np.zeros(n_vars)
-    r2_hdi = np.zeros((n_vars, 2))
-    preds = np.zeros((n_vars, samples, *bartrv.eval().T.shape))
+    r2_mean: npt.NDArray = np.zeros(n_vars)
+    r2_hdi: npt.NDArray = np.zeros((n_vars, 2))
+    preds: npt.NDArray = np.zeros((n_vars, samples, *bartrv.eval().T.shape))
 
     if method == "backward_VI":
         if fixed >= n_vars:
@@ -848,7 +848,7 @@ def compute_variable_importance(  # noqa: PLR0915 PLR0912
         idxs = np.argsort(
             idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
         )
-        subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))]
+        subsets: list[list[int]] = [list(idxs[:-i]) for i in range(1, len(idxs))]
         subsets.append(None)  # type: ignore
 
         if method == "backward_VI":
diff --git a/pyproject.toml b/pyproject.toml
index f8f3e7a..4a2273d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,3 +33,12 @@ exclude_lines = [
 isort = 1
 black = 1
 pyupgrade = 1
+
+
+[tool.mypy]
+files = "pymc_bart/*.py"
+plugins = "numpy.typing.mypy_plugin"
+
+[[tool.mypy.overrides]]
+module = ["matplotlib.*", "numba.*", "pymc.*", "scipy.*"]
+ignore_missing_imports = true
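The `[[tool.mypy.overrides]]` table is the pyproject.toml counterpart of the deleted `[mypy-*]` sections in mypy.ini; per-module settings only take effect through that table, not through ad-hoc `[tool.mypy-<package>]` headers. A hypothetical smoke test (not part of this change) using mypy's scripting API to confirm the configuration is picked up:

```python
# Runs mypy programmatically; with `files` set under [tool.mypy],
# no explicit paths are needed on the command line.
from mypy import api

stdout, stderr, exit_code = api.run(["--config-file", "pyproject.toml"])
print(stdout or stderr)
print("exit status:", exit_code)
```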