Skip to content

Commit 064457e

Browse files
authored
Adds get_variable_inclusion function (#214)
* add get_variable_inclusion function * add elements to API reference
1 parent 3bad2c6 commit 064457e

File tree

3 files changed

+53
-19
lines changed

3 files changed

+53
-19
lines changed

docs/api_reference.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ methods in the current release of PyMC-BART.
1313
=============================
1414

1515
.. automodule:: pymc_bart
16-
:members: BART, PGBART, plot_pdp, plot_ice, plot_variable_importance, plot_convergence, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
16+
:members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule

pymc_bart/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
1919
from pymc_bart.utils import (
2020
compute_variable_importance,
21+
get_variable_inclusion,
2122
plot_convergence,
2223
plot_ice,
2324
plot_pdp,
@@ -33,6 +34,7 @@
3334
"OneHotSplitRule",
3435
"SubsetSplitRule",
3536
"compute_variable_importance",
37+
"get_variable_inclusion",
3638
"plot_convergence",
3739
"plot_ice",
3840
"plot_pdp",

pymc_bart/utils.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,50 @@ def _smooth_mean(
693693
return x_data, y_data
694694

695695

696+
def get_variable_inclusion(idata, X, labels=None, to_kulprit=False):
697+
"""
698+
Get the normalized variable inclusion from BART model.
699+
700+
Parameters
701+
----------
702+
idata : InferenceData
703+
InferenceData containing a collection of BART_trees in sample_stats group
704+
X : npt.NDArray
705+
The covariate matrix.
706+
labels : Optional[list[str]]
707+
List of the names of the covariates. If X is a DataFrame the names of the covariables will
708+
be taken from it and this argument will be ignored.
709+
to_kulprit : bool
710+
If True, the function will return a list of list with the variables names.
711+
This list can be passed as a path to Kulprit's project method. Defaults to False.
712+
Returns
713+
-------
714+
VI_norm : npt.NDArray
715+
Normalized variable inclusion.
716+
labels : list[str]
717+
List of the names of the covariates.
718+
"""
719+
VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
720+
VI_norm = VIs / VIs.sum()
721+
idxs = np.argsort(VI_norm)
722+
723+
indices = idxs[::-1]
724+
n_vars = len(indices)
725+
726+
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
727+
labels = X.columns
728+
729+
if labels is None:
730+
labels = np.arange(n_vars).astype(str)
731+
732+
label_list = labels.to_list()
733+
734+
if to_kulprit:
735+
return [label_list[:idx] for idx in range(n_vars)]
736+
else:
737+
return VI_norm[indices], label_list
738+
739+
696740
def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None):
697741
"""
698742
Plot normalized variable inclusion from BART model.
@@ -720,26 +764,15 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
720764
721765
Returns
722766
-------
723-
idxs: indexes of the covariates from higher to lower relative importance
724767
axes: matplotlib axes
725768
"""
726769
if plot_kwargs is None:
727770
plot_kwargs = {}
728771

729-
VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
730-
VIs = VIs / VIs.sum()
731-
idxs = np.argsort(VIs)
732-
733-
indices = idxs[::-1]
734-
n_vars = len(indices)
735-
736-
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
737-
labels = X.columns
772+
VI_norm, labels = get_variable_inclusion(idata, X, labels)
773+
n_vars = len(labels)
738774

739-
if labels is None:
740-
labels = np.arange(n_vars).astype(str)
741-
742-
new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
775+
new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
743776

744777
ticks = np.arange(n_vars, dtype=int)
745778

@@ -749,19 +782,18 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
749782
if ax is None:
750783
_, ax = plt.subplots(1, 1, figsize=figsize)
751784

785+
ax.axhline(1 / n_vars, color="0.5", linestyle="--")
752786
ax.plot(
753-
VIs[indices],
787+
VI_norm,
754788
color=plot_kwargs.get("color", "k"),
755789
marker=plot_kwargs.get("marker", "o"),
756790
ls=plot_kwargs.get("ls", "-"),
757791
)
758792

759793
ax.set_xticks(ticks, new_labels, rotation=plot_kwargs.get("rotation", 0))
760-
761-
ax.axhline(1 / n_vars, color="0.5", linestyle="--")
762794
ax.set_ylim(0, 1)
763795

764-
return idxs, ax
796+
return ax
765797

766798

767799
def compute_variable_importance( # noqa: PLR0915 PLR0912

0 commit comments

Comments
 (0)