From d72c9638b89e87b8adfc171872258db906f70a95 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Tue, 26 Nov 2024 10:59:35 -0300 Subject: [PATCH 1/2] improve docs, aesthetics and functionality --- pymc_bart/utils.py | 73 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 67 insertions(+), 6 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index e8c60bb..355ceb2 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1058,15 +1058,63 @@ def plot_variable_importance( return ax -def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None): +def plot_scatter_submodels( + vi_results: dict, + X: npt.NDArray[np.float64], + func: Optional[Callable] = None, + grid: str = "long", + figsize: Optional[Tuple[float, float]] = None, + plot_kwargs: Optional[Dict[str, Any]] = None, + axes: Optional[plt.Axes] = None, +): + """ + Plot submodel's predictions against reference-model's predictions. + + Parameters + ---------- + vi_results: Dictionary + Dictionary computed with `compute_variable_importance` + X : npt.NDArray[np.float64] + The covariate matrix. + func : Optional[Callable], by default None. + Arbitrary function to apply to the predictions. Defaults to the identity function. + grid : str or tuple + How to arrange the subplots. Defaults to "long", one subplot below the other. + Other options are "wide", one subplot next to each other or a tuple indicating the number + of rows and columns. + plot_kwargs : dict + Additional keyword arguments for the plot. Defaults to None. + Valid keys are: + - color_ref: matplotlib valid color for the 45 degree line + - color_scatter: matplotlib valid color for the scatter plot + axes : axes + Matplotlib axes. + + Returns + ------- + axes: matplotlib axes + """ indices = vi_results["indices"] preds = vi_results["preds"] preds_all = vi_results["preds_all"] + n_vars = len(indices) if axes is None: - _, axes = _get_axes(grid, len(indices), False, True, None) + _, axes = _get_axes(grid, len(indices), True, True, figsize) + + if plot_kwargs is None: + plot_kwargs = {} + + if hasattr(X, "columns") and hasattr(X, "to_numpy"): + labels = X.columns + + if labels is None: + labels = np.arange(n_vars).astype(str) + else: + labels = np.asarray(labels) + + new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] - func = None if func is not None: preds = func(preds) preds_all = func(preds_all) @@ -1074,9 +1122,22 @@ def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None): min_ = min(np.min(preds), np.min(preds_all)) max_ = max(np.max(preds), np.max(preds_all)) - for pred, ax in zip(preds, axes.ravel()): - ax.plot(pred, preds_all, ".", color="C0", alpha=0.1) - ax.axline([min_, min_], [max_, max_], color="0.5") + for pred, x_label, ax in zip(preds, new_labels, axes.ravel()): + ax.plot( + pred, + preds_all, + marker=plot_kwargs.get("marker_scatter", "."), + ls="", + color=plot_kwargs.get("color_scatter", "C0"), + alpha=plot_kwargs.get("alpha_scatter", 0.1), + ) + ax.set_xlabel(x_label) + ax.axline( + [min_, min_], + [max_, max_], + color=plot_kwargs.get("color_ref", "0.5"), + ls=plot_kwargs.get("ls_ref", "--"), + ) def generate_sequences(n_vars, i_var, include): From 38bf6b0bfd1e304b0cc4dd75695fe0739688f12a Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Tue, 26 Nov 2024 20:26:20 -0300 Subject: [PATCH 2/2] remove X argument from plots --- pymc_bart/utils.py | 38 +++++++++++++++----------------------- tests/test_bart.py | 4 ++-- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 355ceb2..31cc28f 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -824,10 +824,14 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 else: shape = bartrv.eval().shape[0] + n_vars = X.shape[1] + if hasattr(X, "columns") and hasattr(X, "to_numpy"): + labels = X.columns X = X.to_numpy() + else: + labels = np.arange(n_vars).astype(str) - n_vars = X.shape[1] r2_mean = np.zeros(n_vars) r2_hdi = np.zeros((n_vars, 2)) preds = np.zeros((n_vars, samples, bartrv.eval().shape[0])) @@ -947,6 +951,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 vi_results = { "indices": indices, + "labels": labels[indices], "r2_mean": r2_mean, "r2_hdi": r2_hdi, "preds": preds, @@ -957,7 +962,6 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 def plot_variable_importance( vi_results: dict, - X: npt.NDArray[np.float64], labels=None, figsize=None, plot_kwargs: Optional[Dict[str, Any]] = None, @@ -1008,19 +1012,13 @@ def plot_variable_importance( if figsize is None: figsize = (8, 3) - if hasattr(X, "columns") and hasattr(X, "to_numpy"): - labels = X.columns - X = X.to_numpy() - if ax is None: _, ax = plt.subplots(1, 1, figsize=figsize) if labels is None: - labels = np.arange(n_vars).astype(str) - else: - labels = np.asarray(labels) + labels = vi_results["labels"] - new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)]) @@ -1048,7 +1046,7 @@ def plot_variable_importance( ) ax.set_xticks( ticks, - new_labels, + labels, rotation=plot_kwargs.get("rotation", 0), ) ax.set_ylabel("R²", rotation=0, labelpad=12) @@ -1060,9 +1058,9 @@ def plot_variable_importance( def plot_scatter_submodels( vi_results: dict, - X: npt.NDArray[np.float64], func: Optional[Callable] = None, grid: str = "long", + labels=None, figsize: Optional[Tuple[float, float]] = None, plot_kwargs: Optional[Dict[str, Any]] = None, axes: Optional[plt.Axes] = None, @@ -1074,14 +1072,14 @@ def plot_scatter_submodels( ---------- vi_results: Dictionary Dictionary computed with `compute_variable_importance` - X : npt.NDArray[np.float64] - The covariate matrix. func : Optional[Callable], by default None. Arbitrary function to apply to the predictions. Defaults to the identity function. grid : str or tuple How to arrange the subplots. Defaults to "long", one subplot below the other. Other options are "wide", one subplot next to each other or a tuple indicating the number of rows and columns. + labels : Optional[List[str]] + List of the names of the covariates. plot_kwargs : dict Additional keyword arguments for the plot. Defaults to None. Valid keys are: @@ -1097,7 +1095,6 @@ def plot_scatter_submodels( indices = vi_results["indices"] preds = vi_results["preds"] preds_all = vi_results["preds_all"] - n_vars = len(indices) if axes is None: _, axes = _get_axes(grid, len(indices), True, True, figsize) @@ -1105,15 +1102,10 @@ def plot_scatter_submodels( if plot_kwargs is None: plot_kwargs = {} - if hasattr(X, "columns") and hasattr(X, "to_numpy"): - labels = X.columns - if labels is None: - labels = np.arange(n_vars).astype(str) - else: - labels = np.asarray(labels) + labels = vi_results["labels"] - new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] if func is not None: preds = func(preds) @@ -1122,7 +1114,7 @@ def plot_scatter_submodels( min_ = min(np.min(preds), np.min(preds_all)) max_ = max(np.max(preds), np.max(preds_all)) - for pred, x_label, ax in zip(preds, new_labels, axes.ravel()): + for pred, x_label, ax in zip(preds, labels, axes.ravel()): ax.plot( pred, preds_all, diff --git a/tests/test_bart.py b/tests/test_bart.py index c10fc94..c64811a 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -193,8 +193,8 @@ def test_vi(self, kwargs): vi_results = pmb.compute_variable_importance( self.idata, bartrv=self.mu, X=self.X, samples=samples ) - pmb.plot_variable_importance(vi_results, X=self.X, **kwargs) - pmb.plot_scatter_submodels(vi_results) + pmb.plot_variable_importance(vi_results, **kwargs) + pmb.plot_scatter_submodels(vi_results, **kwargs) def test_pdp_pandas_labels(self): pd = pytest.importorskip("pandas")