
Commit d7bbfb4

improve variable importance computation, add backward method (#125)
* improve variable importance
* fix tests
1 parent 2a8b12d commit d7bbfb4

File tree: 4 files changed, +145 −82 lines


pymc_bart/bart.py

Lines changed: 2 additions & 3 deletions
@@ -92,9 +92,8 @@ class BART(Distribution):
         Controls the prior probability over the number of leaves of the trees.
         Should be positive.
     split_prior : Optional[List[float]], default None.
-        Each element of split_prior should be in the [0, 1] interval and the elements should sum to
-        1. Otherwise they will be normalized.
-        Defaults to 0, i.e. all covariates have the same prior probability to be selected.
+        List of positive numbers, one per column in input data.
+        Defaults to None, all covariates have the same prior probability to be selected.
     split_rules : Optional[List[SplitRule]], default None
         List of SplitRule objects, one per column in input data.
         Allows using different split rules for different columns. Default is ContinuousSplitRule.
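
The reworded docstring describes `split_prior` as a plain list of positive per-column weights. A minimal, hypothetical usage sketch (data and variable names are made up; only the `split_prior` semantics come from the docstring above):

```python
import numpy as np
import pymc as pm
import pymc_bart as pmb

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))                  # hypothetical covariates
Y = X[:, 0] + rng.normal(scale=0.1, size=100)  # hypothetical response

with pm.Model():
    # Give the first covariate twice the prior weight of the others;
    # entries only need to be positive, one per column of X.
    mu = pmb.BART("mu", X, Y, m=50, split_prior=[2.0, 1.0, 1.0])
```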

pymc_bart/pgbart.py

Lines changed: 5 additions & 0 deletions
@@ -155,6 +155,11 @@ def __init__(
         else:
             self.split_rules = [ContinuousSplitRule] * self.X.shape[1]

+        jittered = np.random.normal(self.X, self.X.std(axis=0) / 12)
+        min_values = np.min(self.X, axis=0)
+        max_values = np.max(self.X, axis=0)
+        self.X = np.clip(jittered, min_values, max_values)
+
         init_mean = self.bart.Y.mean()
         self.num_observations = self.X.shape[0]
         self.num_variates = self.X.shape[1]
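
The added jitter-and-clip step can be illustrated in isolation; a standalone sketch with made-up data (not part of the commit):

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.normal(size=(100, 3))  # hypothetical design matrix

# Add small Gaussian noise (sigma = per-column std / 12), then clip each column
# back to its observed range so jittered values never leave the data's support.
jittered = np.random.normal(X, X.std(axis=0) / 12)
X_jittered = np.clip(jittered, np.min(X, axis=0), np.max(X, axis=0))

assert (X_jittered >= np.min(X, axis=0)).all()
assert (X_jittered <= np.max(X, axis=0)).all()
```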

pymc_bart/utils.py

Lines changed: 121 additions & 68 deletions
@@ -1,5 +1,6 @@
 """Utility function for variable selection and bart interpretability."""

+from itertools import combinations
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -12,7 +13,6 @@
 from scipy.interpolate import griddata
 from scipy.signal import savgol_filter
 from scipy.stats import norm, pearsonr
-from xarray import concat

 from .tree import Tree

@@ -71,7 +71,6 @@ def _sample_posterior(
             for tree in odim_trees:
                 p[odim] += tree.predict(x=X, excluded=excluded, shape=leaves_shape)

-    # pred.reshape((*size_iter, shape, -1))
     return pred.transpose((0, 3, 1, 2)).reshape((*size_iter, -1, shape))


@@ -714,11 +713,12 @@ def plot_variable_importance(
     bartrv: Variable,
     X: npt.NDArray[np.float_],
     labels: Optional[List[str]] = None,
-    sort_vars: bool = True,
+    method: str = "VI",
     figsize: Optional[Tuple[float, float]] = None,
+    xlabel_angle: float = 0,
     samples: int = 100,
     random_seed: Optional[int] = None,
-) -> Tuple[npt.NDArray[np.int_], List[plt.Axes]]:
+) -> Tuple[List[int], List[plt.Axes]]:
     """
     Estimates variable importance from the BART-posterior.

@@ -733,10 +733,17 @@ def plot_variable_importance(
     labels : Optional[List[str]]
         List of the names of the covariates. If X is a DataFrame the names of the covariables will
        be taken from it and this argument will be ignored.
-    sort_vars : bool
-        Whether to sort the variables according to their variable importance. Defaults to True.
+    method : str
+        Method used to rank variables. Available options are "VI" (default) and "backward".
+        The R squared will be computed following this ranking.
+        "VI" counts how many times each variable is included in the posterior distribution
+        of trees. "backward" uses a backward search based on the R squared.
+        VI requieres less computation time.
     figsize : tuple
         Figure size. If None it will be defined automatically.
+    xlabel_angle : float
+        rotation angle of the x-axis labels. Defaults to 0. Use values like 45 for
+        long labels and/or many variables.
     samples : int
         Number of predictions used to compute correlation for subsets of variables. Defaults to 100
     random_seed : Optional[int]
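
For orientation, a hypothetical call with the new signature (the fitted `idata`, the BART variable `mu`, and the data `X` are assumed from an earlier model, not shown in this commit):

```python
import pymc_bart as pmb

# Assumes a previously fitted model along the lines of:
#   with pm.Model():
#       mu = pmb.BART("mu", X, Y, m=50)
#       ...
#       idata = pm.sample()

# Default counting-based ranking (cheaper):
vi_order, ax = pmb.plot_variable_importance(idata, mu, X, method="VI")

# Backward search driven by R², with rotated labels for long covariate names:
bw_order, ax = pmb.plot_variable_importance(
    idata, mu, X, method="backward", samples=50, xlabel_angle=45
)
```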
@@ -747,7 +754,9 @@ def plot_variable_importance(
     idxs: indexes of the covariates from higher to lower relative importance
     axes: matplotlib axes
     """
-    _, axes = plt.subplots(2, 1, figsize=figsize)
+    rng = np.random.default_rng(random_seed)
+
+    all_trees = bartrv.owner.op.all_trees

     if bartrv.ndim == 1:  # type: ignore
         shape = 1
@@ -758,80 +767,124 @@ def plot_variable_importance(
         labels = X.columns
         X = X.values

-    n_draws = idata["posterior"].dims["draw"]
-    half = n_draws // 2
-    f_half = idata["sample_stats"]["variable_inclusion"].sel(draw=slice(0, half - 1))
-    s_half = idata["sample_stats"]["variable_inclusion"].sel(draw=slice(half, n_draws))
+    n_vars = X.shape[1]
+
+    if figsize is None:
+        figsize = (8, 3)
+
+    _, ax = plt.subplots(1, 1, figsize=figsize)

-    var_imp_chains = concat([f_half, s_half], dim="chain", join="override").mean(("draw")).values
-    var_imp = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
     if labels is None:
-        labels_ary = np.arange(len(var_imp))
+        labels_ary = np.arange(n_vars).astype(str)
     else:
         labels_ary = np.array(labels)

-    rng = np.random.default_rng(random_seed)
+    ticks = np.arange(n_vars, dtype=int)

-    ticks = np.arange(len(var_imp), dtype=int)
-    idxs = np.argsort(var_imp)
-    subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))]
-    subsets.append(None)  # type: ignore
+    predicted_all = _sample_posterior(
+        all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
+    )

-    if sort_vars:
-        indices = idxs[::-1]
-    else:
-        indices = np.arange(len(var_imp))
+    if method == "VI":
+        idxs = np.argsort(
+            idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
+        )
+        subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))]
+        subsets.append(None)  # type: ignore
+
+        indices: List[int] = list(idxs[::-1])
+
+        r2_mean = np.zeros(n_vars)
+        r2_hdi = np.zeros((n_vars, 2))
+        for idx, subset in enumerate(subsets):
+            predicted_subset = _sample_posterior(
+                all_trees=all_trees,
+                X=X,
+                rng=rng,
+                size=samples,
+                excluded=subset,
+                shape=shape,
+            )
+            pearson = np.zeros(samples)
+            for j in range(samples):
+                pearson[j] = (
+                    pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
+                ) ** 2
+            r2_mean[idx] = np.mean(pearson)
+            r2_hdi[idx] = az.hdi(pearson)
+
+    elif method == "backward":
+        r2_mean = np.zeros(n_vars)
+        r2_hdi = np.zeros((n_vars, 2))
+
+        variables = set(range(n_vars))
+        excluded: List[int] = []
+        indices = []
+
+        for i_var in range(0, n_vars):
+            subsets = _generate_combinations(variables, excluded)
+            max_pearson = -np.inf
+            for subset in subsets:
+                predicted_subset = _sample_posterior(
+                    all_trees=all_trees,
+                    X=X,
+                    rng=rng,
+                    size=samples,
+                    excluded=subset,
+                    shape=shape,
+                )
+                pearson = np.zeros(samples)
+                for j in range(samples):
+                    pearson[j] = (
+                        pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
+                    ) ** 2
+                mean_pearson = np.mean(pearson, dtype=float)
+                if mean_pearson > max_pearson:
+                    max_pearson = mean_pearson
+                    best_subset = subset
+                    best_pearson = pearson

-    chains_mean = (var_imp / var_imp.sum())[indices]
-    chains_hdi = az.hdi((var_imp_chains.T / var_imp_chains.sum(axis=1)).T)[indices]
+            r2_mean[i_var] = max_pearson
+            r2_hdi[i_var] = az.hdi(best_pearson)

-    axes[0].errorbar(
+            indices.extend((set(best_subset) - set(indices)))
+
+            excluded.append(best_subset)
+
+        indices.extend((set(variables) - set(indices)))
+
+        indices = indices[::-1]
+        r2_mean = r2_mean[::-1]
+        r2_hdi = r2_hdi[::-1]
+
+    new_labels = [
+        "+ " + ele if index != 0 else ele for index, ele in enumerate(labels_ary[indices])
+    ]
+
+    r2_yerr_min = np.clip(r2_mean - r2_hdi[:, 0], 0, None)
+    r2_yerr_max = np.clip(r2_hdi[:, 1] - r2_mean, 0, None)
+    ax.errorbar(
         ticks,
-        chains_mean,
-        np.array((chains_mean - chains_hdi[:, 0], chains_hdi[:, 1] - chains_mean)),
+        r2_mean,
+        np.array((r2_yerr_min, r2_yerr_max)),
         color="C0",
     )
-    axes[0].set_xticks(ticks)
-    axes[0].set_xticklabels(labels_ary[indices])
-    axes[0].set_xlabel("covariables")
-    axes[0].set_ylabel("importance")
-
-    all_trees = bartrv.owner.op.all_trees
+    ax.axhline(r2_mean[-1], ls="--", color="0.5")
+    ax.set_xticks(ticks, new_labels, rotation=xlabel_angle)
+    ax.set_ylabel("R²", rotation=0, labelpad=12)
+    ax.set_ylim(0, 1)
+    ax.set_xlim(-0.5, n_vars - 0.5)

-    predicted_all = _sample_posterior(
-        all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
-    )
+    return indices, ax

-    ev_mean = np.zeros(len(var_imp))
-    ev_hdi = np.zeros((len(var_imp), 2))
-    for idx, subset in enumerate(subsets):
-        predicted_subset = _sample_posterior(
-            all_trees=all_trees,
-            X=X,
-            rng=rng,
-            size=samples,
-            excluded=subset,
-            shape=shape,
-        )
-        pearson = np.zeros(samples)
-        for j in range(samples):
-            pearson[j] = (
-                pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
-            ) ** 2
-        ev_mean[idx] = np.mean(pearson)
-        ev_hdi[idx] = az.hdi(pearson)
-
-    axes[1].errorbar(
-        ticks, ev_mean, np.array((ev_mean - ev_hdi[:, 0], ev_hdi[:, 1] - ev_mean)), color="C0"
-    )
-    axes[1].axhline(ev_mean[-1], ls="--", color="0.5")
-    axes[1].set_xticks(ticks)
-    axes[1].set_xticklabels(ticks + 1)
-    axes[1].set_xlabel("number of covariables")
-    axes[1].set_ylabel("R²", rotation=0, labelpad=12)
-    axes[1].set_ylim(0, 1)

-    axes[0].set_xlim(-0.5, len(var_imp) - 0.5)
-    axes[1].set_xlim(-0.5, len(var_imp) - 0.5)
+def _generate_combinations(variables, excluded):
+    """
+    Generate all possible combinations of variables.
+    """
+    all_combinations = combinations(variables, len(excluded))
+    valid_combinations = [
+        com for com in all_combinations if not any(ele in com for ele in excluded)
+    ]

-    return idxs[::-1], axes
+    return valid_combinations
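
Both ranking methods score a candidate subset by the squared Pearson correlation between predictions from the full set of trees and predictions with some covariates excluded. A standalone sketch of that metric on made-up prediction arrays (not part of the commit):

```python
import numpy as np
from scipy.stats import pearsonr

rng = np.random.default_rng(0)
samples, n_obs = 100, 50

# Hypothetical posterior predictions: full model vs. model with covariates excluded.
predicted_all = rng.normal(size=(samples, n_obs))
predicted_subset = predicted_all + rng.normal(scale=0.3, size=(samples, n_obs))

# Per-draw R²: squared Pearson correlation between the two prediction vectors.
r2 = np.array(
    [
        pearsonr(predicted_all[j].ravel(), predicted_subset[j].ravel())[0] ** 2
        for j in range(samples)
    ]
)
print(r2.mean())  # values near 1 suggest the excluded covariates add little
```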

tests/test_bart.py

Lines changed: 17 additions & 11 deletions
@@ -98,7 +98,7 @@ def test_shared_variable(response):
         y = pm.Normal("y", mu, sigma, observed=Y, shape=mu.shape)
         idata = pm.sample(tune=100, draws=100, chains=2, random_seed=3415)
         ppc = pm.sample_posterior_predictive(idata)
-        new_X = pm.set_data({"data_X": X[:3]})
+        pm.set_data({"data_X": X[:3]})
         ppc2 = pm.sample_posterior_predictive(idata)

     assert ppc.posterior_predictive["y"].shape == (2, 100, 50)
@@ -160,7 +160,7 @@ def test_sample_posterior(self):
             {"instances": 2},
             {"var_idx": [0], "smooth": False, "color": "k"},
             {"grid": (1, 2), "sharey": "none", "alpha": 1},
-            {"var_discrete": [0]}
+            {"var_discrete": [0]},
         ],
     )
     def test_ice(self, kwargs):
@@ -178,7 +178,7 @@ def test_ice(self, kwargs):
             },
             {"var_idx": [0], "smooth": False, "color": "k"},
             {"grid": (1, 2), "sharey": "none", "alpha": 1},
-            {"var_discrete": [0]}
+            {"var_discrete": [0]},
         ],
     )
     def test_pdp(self, kwargs):
@@ -224,22 +224,28 @@ def test_bart_moment(size, expected):
 @pytest.mark.parametrize(
     argnames="separate_trees,split_rule",
     argvalues=[
-        (False,pmb.ContinuousSplitRule),
-        (False,pmb.OneHotSplitRule),
-        (False,pmb.SubsetSplitRule),
-        (True,pmb.ContinuousSplitRule)
+        (False, pmb.ContinuousSplitRule),
+        (False, pmb.OneHotSplitRule),
+        (False, pmb.SubsetSplitRule),
+        (True, pmb.ContinuousSplitRule),
     ],
     ids=["continuous", "one-hot", "subset", "separate-trees"],
 )
-def test_categorical_model(separate_trees,split_rule):
+def test_categorical_model(separate_trees, split_rule):

     Y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
     X = np.concatenate([Y[:, None], np.random.randint(0, 6, size=(9, 4))], axis=1)

     with pm.Model() as model:
-        lo = pmb.BART("logodds", X, Y, m=2, shape=(3, 9),
-                      split_rules=[split_rule]*5,
-                      separate_trees=separate_trees)
+        lo = pmb.BART(
+            "logodds",
+            X,
+            Y,
+            m=2,
+            shape=(3, 9),
+            split_rules=[split_rule] * 5,
+            separate_trees=separate_trees,
+        )
         y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y)
         idata = pm.sample(random_seed=3415, tune=300, draws=300)
         idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)
