1
1
"""Utility function for variable selection and bart interpretability."""
2
2
3
- from itertools import combinations
4
3
import warnings
5
4
from typing import Any , Callable , Dict , List , Optional , Tuple , Union
6
5
@@ -805,26 +804,35 @@ def plot_variable_importance(
805
804
excluded = subset ,
806
805
shape = shape ,
807
806
)
808
- pearson = np .zeros (samples )
809
- for j in range (samples ):
810
- pearson [j ] = (
811
- pearsonr (predicted_all [j ].flatten (), predicted_subset [j ].flatten ())[0 ]
812
- ) ** 2
813
- r2_mean [idx ] = np .mean (pearson )
814
- r2_hdi [idx ] = az .hdi (pearson )
807
+ r_2 = np .array (
808
+ [
809
+ pearsonr (predicted_all [j ].flatten (), predicted_subset [j ].flatten ())[0 ] ** 2
810
+ for j in range (samples )
811
+ ]
812
+ )
813
+ r2_mean [idx ] = np .mean (r_2 )
814
+ r2_hdi [idx ] = az .hdi (r_2 )
815
815
816
816
elif method == "backward" :
817
817
r2_mean = np .zeros (n_vars )
818
818
r2_hdi = np .zeros ((n_vars , 2 ))
819
819
820
820
variables = set (range (n_vars ))
821
- excluded : List [int ] = []
821
+ least_important_vars : List [int ] = []
822
822
indices = []
823
823
824
- for i_var in range (0 , n_vars ):
825
- subsets = _generate_combinations (variables , excluded )
826
- max_pearson = - np .inf
824
+ # Iterate over each variable to determine its contribution
825
+ # least_important_vars tracks the variable with the lowest contribution
826
+ # at the current stage. One new variable is added at each iteration.
827
+ for i_var in range (n_vars ):
828
+ # Generate all possible subsets by adding one variable at a time to
829
+ # least_important_vars
830
+ subsets = generate_sequences (n_vars , i_var , least_important_vars )
831
+ max_r_2 = - np .inf
832
+
833
+ # Iterate over each subset to find the one with the maximum Pearson correlation
827
834
for subset in subsets :
835
+ # Sample posterior predictions excluding a subset of variables
828
836
predicted_subset = _sample_posterior (
829
837
all_trees = all_trees ,
830
838
X = X ,
@@ -833,25 +841,32 @@ def plot_variable_importance(
833
841
excluded = subset ,
834
842
shape = shape ,
835
843
)
836
- pearson = np .zeros (samples )
844
+ # Calculate Pearson correlation for each sample and find the mean
845
+ r_2 = np .zeros (samples )
837
846
for j in range (samples ):
838
- pearson [j ] = (
847
+ r_2 [j ] = (
839
848
pearsonr (predicted_all [j ].flatten (), predicted_subset [j ].flatten ())[0 ]
840
849
) ** 2
841
- mean_pearson = np .mean (pearson , dtype = float )
842
- if mean_pearson > max_pearson :
843
- max_pearson = mean_pearson
844
- best_subset = subset
845
- best_pearson = pearson
850
+ mean_r_2 = np .mean (r_2 , dtype = float )
851
+ # Identify the least important combination of variables
852
+ # based on the maximum mean squared Pearson correlation
853
+ if mean_r_2 > max_r_2 :
854
+ max_r_2 = mean_r_2
855
+ least_important_subset = subset
856
+ r_2_without_least_important_vars = r_2
846
857
847
- r2_mean [i_var ] = max_pearson
848
- r2_hdi [i_var ] = az .hdi (best_pearson )
858
+ # Save values for plotting later
859
+ r2_mean [i_var ] = max_r_2
860
+ r2_hdi [i_var ] = az .hdi (r_2_without_least_important_vars )
849
861
850
- indices .extend ((set (best_subset ) - set (indices )))
862
+ # extend current list of least important variable
863
+ least_important_vars += least_important_subset
851
864
852
- excluded .append (best_subset )
865
+ # add index of removed variable
866
+ indices += list (set (least_important_subset ) - set (indices ))
853
867
854
- indices .extend ((set (variables ) - set (indices )))
868
+ # add remaining index
869
+ indices += list (set (variables ) - set (least_important_vars ))
855
870
856
871
indices = indices [::- 1 ]
857
872
r2_mean = r2_mean [::- 1 ]
@@ -878,13 +893,10 @@ def plot_variable_importance(
878
893
return indices , ax
879
894
880
895
881
- def _generate_combinations (variables , excluded ):
882
- """
883
- Generate all possible combinations of variables.
884
- """
885
- all_combinations = combinations (variables , len (excluded ))
886
- valid_combinations = [
887
- com for com in all_combinations if not any (ele in com for ele in excluded )
888
- ]
889
-
890
- return valid_combinations
896
def generate_sequences(n_vars: int, i_var: int, include) -> List[Tuple[int, ...]]:
    """Generate candidate variable sequences that extend ``include`` by one variable.

    Parameters
    ----------
    n_vars : int
        Total number of variables; candidates are drawn from ``range(n_vars)``.
    i_var : int
        Index of the current iteration. When it is falsy (i.e. ``0``) a single
        empty sequence is returned and ``include`` is ignored.
    include : iterable of int
        Variables already selected. Any iterable is accepted (list, tuple,
        array, generator); it is read exactly once and its order and any
        duplicate entries are preserved as the prefix of every sequence.

    Returns
    -------
    list of tuple of int
        ``[()]`` when ``i_var`` is falsy; otherwise one tuple per variable in
        ``range(n_vars)`` that is not in ``include``, each made of the
        elements of ``include`` followed by that variable.
    """
    if not i_var:
        return [()]

    # Materialise once: makes the prefix independent of the input container
    # type (``include + [i]`` only behaved correctly for lists).
    prefix = tuple(include)
    already_included = set(prefix)
    return [prefix + (i,) for i in range(n_vars) if i not in already_included]
0 commit comments