diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index e8c60bb..31cc28f 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -824,10 +824,14 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 else: shape = bartrv.eval().shape[0] + n_vars = X.shape[1] + if hasattr(X, "columns") and hasattr(X, "to_numpy"): + labels = X.columns X = X.to_numpy() + else: + labels = np.arange(n_vars).astype(str) - n_vars = X.shape[1] r2_mean = np.zeros(n_vars) r2_hdi = np.zeros((n_vars, 2)) preds = np.zeros((n_vars, samples, bartrv.eval().shape[0])) @@ -947,6 +951,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 vi_results = { "indices": indices, + "labels": labels[indices], "r2_mean": r2_mean, "r2_hdi": r2_hdi, "preds": preds, @@ -957,7 +962,6 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 def plot_variable_importance( vi_results: dict, - X: npt.NDArray[np.float64], labels=None, figsize=None, plot_kwargs: Optional[Dict[str, Any]] = None, @@ -1008,19 +1012,13 @@ def plot_variable_importance( if figsize is None: figsize = (8, 3) - if hasattr(X, "columns") and hasattr(X, "to_numpy"): - labels = X.columns - X = X.to_numpy() - if ax is None: _, ax = plt.subplots(1, 1, figsize=figsize) if labels is None: - labels = np.arange(n_vars).astype(str) - else: - labels = np.asarray(labels) + labels = vi_results["labels"] - new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)]) @@ -1048,7 +1046,7 @@ def plot_variable_importance( ) ax.set_xticks( ticks, - new_labels, + labels, rotation=plot_kwargs.get("rotation", 0), ) ax.set_ylabel("R²", rotation=0, labelpad=12) @@ -1058,15 +1056,57 @@ def plot_variable_importance( return ax -def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None): +def plot_scatter_submodels( + vi_results: 
dict, + func: Optional[Callable] = None, + grid: str = "long", + labels=None, + figsize: Optional[Tuple[float, float]] = None, + plot_kwargs: Optional[Dict[str, Any]] = None, + axes: Optional[plt.Axes] = None, +): + """ + Plot submodels' predictions against the reference model's predictions. + + Parameters + ---------- + vi_results : dict + Dictionary computed with `compute_variable_importance`. + func : Optional[Callable], by default None. + Arbitrary function to apply to the predictions. Defaults to the identity function. + grid : str or tuple + How to arrange the subplots. Defaults to "long", one subplot below the other. + Other options are "wide", one subplot next to the other, or a tuple indicating the number + of rows and columns. + labels : Optional[List[str]] + List of the names of the covariates. Defaults to `vi_results["labels"]`. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - color_ref, ls_ref: color and line style of the 45 degree line + - color_scatter, marker_scatter, alpha_scatter: style of the scatter points + axes : axes + Matplotlib axes. 
+ + Returns + ------- + axes: matplotlib axes + """ indices = vi_results["indices"] preds = vi_results["preds"] preds_all = vi_results["preds_all"] if axes is None: - _, axes = _get_axes(grid, len(indices), False, True, None) + _, axes = _get_axes(grid, len(indices), True, True, figsize) + + if plot_kwargs is None: + plot_kwargs = {} + + if labels is None: + labels = vi_results["labels"] + + labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] - func = None if func is not None: preds = func(preds) preds_all = func(preds_all) @@ -1074,9 +1114,22 @@ def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None): min_ = min(np.min(preds), np.min(preds_all)) max_ = max(np.max(preds), np.max(preds_all)) - for pred, ax in zip(preds, axes.ravel()): - ax.plot(pred, preds_all, ".", color="C0", alpha=0.1) - ax.axline([min_, min_], [max_, max_], color="0.5") + for pred, x_label, ax in zip(preds, labels, axes.ravel()): + ax.plot( + pred, + preds_all, + marker=plot_kwargs.get("marker_scatter", "."), + ls="", + color=plot_kwargs.get("color_scatter", "C0"), + alpha=plot_kwargs.get("alpha_scatter", 0.1), + ) + ax.set_xlabel(x_label) + ax.axline( + [min_, min_], + [max_, max_], + color=plot_kwargs.get("color_ref", "0.5"), + ls=plot_kwargs.get("ls_ref", "--"), + ) def generate_sequences(n_vars, i_var, include): diff --git a/tests/test_bart.py b/tests/test_bart.py index c10fc94..c64811a 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -193,8 +193,8 @@ def test_vi(self, kwargs): vi_results = pmb.compute_variable_importance( self.idata, bartrv=self.mu, X=self.X, samples=samples ) - pmb.plot_variable_importance(vi_results, X=self.X, **kwargs) - pmb.plot_scatter_submodels(vi_results) + pmb.plot_variable_importance(vi_results, **kwargs) + pmb.plot_scatter_submodels(vi_results, **kwargs) def test_pdp_pandas_labels(self): pd = pytest.importorskip("pandas")