Skip to content

Commit 83f2409

Browse files
authored
fix bug, clean code and add comments (#132)
1 parent 9deb502 commit 83f2409

File tree

1 file changed

+46
-34
lines changed

1 file changed

+46
-34
lines changed

pymc_bart/utils.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Utility function for variable selection and bart interpretability."""
22

3-
from itertools import combinations
43
import warnings
54
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
65

@@ -805,26 +804,35 @@ def plot_variable_importance(
805804
excluded=subset,
806805
shape=shape,
807806
)
808-
pearson = np.zeros(samples)
809-
for j in range(samples):
810-
pearson[j] = (
811-
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
812-
) ** 2
813-
r2_mean[idx] = np.mean(pearson)
814-
r2_hdi[idx] = az.hdi(pearson)
807+
r_2 = np.array(
808+
[
809+
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] ** 2
810+
for j in range(samples)
811+
]
812+
)
813+
r2_mean[idx] = np.mean(r_2)
814+
r2_hdi[idx] = az.hdi(r_2)
815815

816816
elif method == "backward":
817817
r2_mean = np.zeros(n_vars)
818818
r2_hdi = np.zeros((n_vars, 2))
819819

820820
variables = set(range(n_vars))
821-
excluded: List[int] = []
821+
least_important_vars: List[int] = []
822822
indices = []
823823

824-
for i_var in range(0, n_vars):
825-
subsets = _generate_combinations(variables, excluded)
826-
max_pearson = -np.inf
824+
# Iterate over each variable to determine its contribution
825+
# least_important_vars tracks the variable with the lowest contribution
826+
# at the current stage. One new varible is added at each iteration.
827+
for i_var in range(n_vars):
828+
# Generate all possible subsets by adding one variable at a time to
829+
# least_important_vars
830+
subsets = generate_sequences(n_vars, i_var, least_important_vars)
831+
max_r_2 = -np.inf
832+
833+
# Iterate over each subset to find the one with the maximum Pearson correlation
827834
for subset in subsets:
835+
# Sample posterior predictions excluding a subset of variables
828836
predicted_subset = _sample_posterior(
829837
all_trees=all_trees,
830838
X=X,
@@ -833,25 +841,32 @@ def plot_variable_importance(
833841
excluded=subset,
834842
shape=shape,
835843
)
836-
pearson = np.zeros(samples)
844+
# Calculate Pearson correlation for each sample and find the mean
845+
r_2 = np.zeros(samples)
837846
for j in range(samples):
838-
pearson[j] = (
847+
r_2[j] = (
839848
pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
840849
) ** 2
841-
mean_pearson = np.mean(pearson, dtype=float)
842-
if mean_pearson > max_pearson:
843-
max_pearson = mean_pearson
844-
best_subset = subset
845-
best_pearson = pearson
850+
mean_r_2 = np.mean(r_2, dtype=float)
851+
# Identify the least important combination of variables
852+
# based on the maximum mean squared Pearson correlation
853+
if mean_r_2 > max_r_2:
854+
max_r_2 = mean_r_2
855+
least_important_subset = subset
856+
r_2_without_least_important_vars = r_2
846857

847-
r2_mean[i_var] = max_pearson
848-
r2_hdi[i_var] = az.hdi(best_pearson)
858+
# Save values for plotting later
859+
r2_mean[i_var] = max_r_2
860+
r2_hdi[i_var] = az.hdi(r_2_without_least_important_vars)
849861

850-
indices.extend((set(best_subset) - set(indices)))
862+
# extend current list of least important variable
863+
least_important_vars += least_important_subset
851864

852-
excluded.append(best_subset)
865+
# add index of removed variable
866+
indices += list(set(least_important_subset) - set(indices))
853867

854-
indices.extend((set(variables) - set(indices)))
868+
# add remaining index
869+
indices += list(set(variables) - set(least_important_vars))
855870

856871
indices = indices[::-1]
857872
r2_mean = r2_mean[::-1]
@@ -878,13 +893,10 @@ def plot_variable_importance(
878893
return indices, ax
879894

880895

881-
def _generate_combinations(variables, excluded):
882-
"""
883-
Generate all possible combinations of variables.
884-
"""
885-
all_combinations = combinations(variables, len(excluded))
886-
valid_combinations = [
887-
com for com in all_combinations if not any(ele in com for ele in excluded)
888-
]
889-
890-
return valid_combinations
896+
def generate_sequences(n_vars, i_var, include):
897+
"""Generate combinations of variables"""
898+
if i_var:
899+
sequences = [tuple(include + [i]) for i in range(n_vars) if i not in include]
900+
else:
901+
sequences = [()]
902+
return sequences

0 commit comments

Comments
 (0)