@@ -1,5 +1,6 @@
 """Utility function for variable selection and bart interpretability."""

+from itertools import combinations
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -12,6 +13,5 @@
 from scipy.interpolate import griddata
 from scipy.signal import savgol_filter
 from scipy.stats import norm, pearsonr
-from xarray import concat

 from .tree import Tree
@@ -71,7 +71,6 @@ def _sample_posterior(
             for tree in odim_trees:
                 p[odim] += tree.predict(x=X, excluded=excluded, shape=leaves_shape)

-    # pred.reshape((*size_iter, shape, -1))
     return pred.transpose((0, 3, 1, 2)).reshape((*size_iter, -1, shape))

@@ -714,11 +713,12 @@ def plot_variable_importance(
     bartrv: Variable,
     X: npt.NDArray[np.float_],
     labels: Optional[List[str]] = None,
-    sort_vars: bool = True,
+    method: str = "VI",
     figsize: Optional[Tuple[float, float]] = None,
+    xlabel_angle: float = 0,
     samples: int = 100,
     random_seed: Optional[int] = None,
-) -> Tuple[npt.NDArray[np.int_], List[plt.Axes]]:
+) -> Tuple[List[int], plt.Axes]:
     """
     Estimates variable importance from the BART-posterior.

@@ -733,10 +733,17 @@ def plot_variable_importance(
     labels : Optional[List[str]]
         List of the names of the covariates. If X is a DataFrame the names of the covariates will
         be taken from it and this argument will be ignored.
-    sort_vars : bool
-        Whether to sort the variables according to their variable importance. Defaults to True.
+    method : str
+        Method used to rank variables. Available options are "VI" (default) and "backward".
+        The R squared will be computed following this ranking.
+        "VI" counts how many times each variable is included in the posterior distribution
+        of trees. "backward" uses a backward search based on the R squared.
+        "VI" requires less computation time.
     figsize : tuple
         Figure size. If None it will be defined automatically.
+    xlabel_angle : float
+        Rotation angle of the x-axis labels. Defaults to 0. Use values like 45 for
+        long labels and/or many variables.
     samples : int
         Number of predictions used to compute correlation for subsets of variables. Defaults to 100.
     random_seed : Optional[int]
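A usage sketch of the two methods (illustrative, not part of this commit; it assumes the function's unchanged leading parameter `idata`, visible in the surrounding file, and the toy data below):

import numpy as np
import pymc as pm
import pymc_bart as pmb

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X[:, 0] + rng.normal(scale=0.1, size=100)

with pm.Model():
    mu = pmb.BART("mu", X, y, m=20)  # BART prior over the mean
    pm.Normal("y_obs", mu=mu, sigma=0.1, observed=y)
    idata = pm.sample(random_seed=0)

# Fast ranking from the posterior inclusion counts:
indices, ax = pmb.plot_variable_importance(idata, mu, X, method="VI")

# Exhaustive search over subsets; slower, but driven by R²:
indices, ax = pmb.plot_variable_importance(
    idata, mu, X, method="backward", xlabel_angle=45
)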
@@ -747,7 +754,9 @@ def plot_variable_importance(
     idxs: indexes of the covariates from higher to lower relative importance
     axes: matplotlib axes
     """
-    _, axes = plt.subplots(2, 1, figsize=figsize)
+    rng = np.random.default_rng(random_seed)
+
+    all_trees = bartrv.owner.op.all_trees

     if bartrv.ndim == 1:  # type: ignore
         shape = 1
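The R² used throughout the new code is the squared Pearson correlation between posterior predictions of the full model and of a model with some variables excluded. A minimal self-contained sketch with made-up arrays:

import numpy as np
from scipy.stats import pearsonr

rng = np.random.default_rng(0)
full_pred = rng.normal(size=200)  # stand-in for full-model predictions
subset_pred = full_pred + rng.normal(scale=0.3, size=200)  # stand-in for submodel predictions

# pearsonr returns (r, p-value); squaring r gives a value in [0, 1],
# where 1 means the submodel predicts exactly like the full model.
r_squared = pearsonr(full_pred, subset_pred)[0] ** 2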
@@ -758,80 +767,124 @@ def plot_variable_importance(
         labels = X.columns
         X = X.values

-    n_draws = idata["posterior"].dims["draw"]
-    half = n_draws // 2
-    f_half = idata["sample_stats"]["variable_inclusion"].sel(draw=slice(0, half - 1))
-    s_half = idata["sample_stats"]["variable_inclusion"].sel(draw=slice(half, n_draws))
+    n_vars = X.shape[1]
+
+    if figsize is None:
+        figsize = (8, 3)
+
+    _, ax = plt.subplots(1, 1, figsize=figsize)

-    var_imp_chains = concat([f_half, s_half], dim="chain", join="override").mean(("draw")).values
-    var_imp = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
     if labels is None:
-        labels_ary = np.arange(len(var_imp))
+        labels_ary = np.arange(n_vars).astype(str)
     else:
         labels_ary = np.array(labels)

-    rng = np.random.default_rng(random_seed)
+    ticks = np.arange(n_vars, dtype=int)

-    ticks = np.arange(len(var_imp), dtype=int)
-    idxs = np.argsort(var_imp)
-    subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))]
-    subsets.append(None)  # type: ignore
+    predicted_all = _sample_posterior(
+        all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
+    )

-    if sort_vars:
-        indices = idxs[::-1]
-    else:
-        indices = np.arange(len(var_imp))
+    if method == "VI":
+        idxs = np.argsort(
+            idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
+        )
+        subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))]
+        subsets.append(None)  # type: ignore
+
+        indices: List[int] = list(idxs[::-1])
+
+        r2_mean = np.zeros(n_vars)
+        r2_hdi = np.zeros((n_vars, 2))
+        for idx, subset in enumerate(subsets):
+            predicted_subset = _sample_posterior(
+                all_trees=all_trees,
+                X=X,
+                rng=rng,
+                size=samples,
+                excluded=subset,
+                shape=shape,
+            )
+            pearson = np.zeros(samples)
+            for j in range(samples):
+                pearson[j] = (
+                    pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
+                ) ** 2
+            r2_mean[idx] = np.mean(pearson)
+            r2_hdi[idx] = az.hdi(pearson)
+
+    elif method == "backward":
+        r2_mean = np.zeros(n_vars)
+        r2_hdi = np.zeros((n_vars, 2))
+
+        variables = set(range(n_vars))
+        excluded: List[Tuple[int, ...]] = []
+        indices = []
+
+        for i_var in range(0, n_vars):
+            subsets = _generate_combinations(variables, excluded)
+            max_pearson = -np.inf
+            for subset in subsets:
+                predicted_subset = _sample_posterior(
+                    all_trees=all_trees,
+                    X=X,
+                    rng=rng,
+                    size=samples,
+                    excluded=subset,
+                    shape=shape,
+                )
+                pearson = np.zeros(samples)
+                for j in range(samples):
+                    pearson[j] = (
+                        pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
+                    ) ** 2
+                mean_pearson = np.mean(pearson, dtype=float)
+                if mean_pearson > max_pearson:
+                    max_pearson = mean_pearson
+                    best_subset = subset
+                    best_pearson = pearson

-    chains_mean = (var_imp / var_imp.sum())[indices]
-    chains_hdi = az.hdi((var_imp_chains.T / var_imp_chains.sum(axis=1)).T)[indices]
+            r2_mean[i_var] = max_pearson
+            r2_hdi[i_var] = az.hdi(best_pearson)

-    axes[0].errorbar(
+            indices.extend((set(best_subset) - set(indices)))
+
+            excluded.append(best_subset)
+
+        indices.extend((set(variables) - set(indices)))
+
+        indices = indices[::-1]
+        r2_mean = r2_mean[::-1]
+        r2_hdi = r2_hdi[::-1]
+
+    new_labels = [
+        "+ " + ele if index != 0 else ele for index, ele in enumerate(labels_ary[indices])
+    ]
+
+    r2_yerr_min = np.clip(r2_mean - r2_hdi[:, 0], 0, None)
+    r2_yerr_max = np.clip(r2_hdi[:, 1] - r2_mean, 0, None)
+    ax.errorbar(
         ticks,
-        chains_mean,
-        np.array((chains_mean - chains_hdi[:, 0], chains_hdi[:, 1] - chains_mean)),
+        r2_mean,
+        np.array((r2_yerr_min, r2_yerr_max)),
         color="C0",
     )
-    axes[0].set_xticks(ticks)
-    axes[0].set_xticklabels(labels_ary[indices])
-    axes[0].set_xlabel("covariables")
-    axes[0].set_ylabel("importance")
-
-    all_trees = bartrv.owner.op.all_trees
+    ax.axhline(r2_mean[-1], ls="--", color="0.5")
+    ax.set_xticks(ticks, new_labels, rotation=xlabel_angle)
+    ax.set_ylabel("R²", rotation=0, labelpad=12)
+    ax.set_ylim(0, 1)
+    ax.set_xlim(-0.5, n_vars - 0.5)

-    predicted_all = _sample_posterior(
-        all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
-    )
+    return indices, ax

-    ev_mean = np.zeros(len(var_imp))
-    ev_hdi = np.zeros((len(var_imp), 2))
-    for idx, subset in enumerate(subsets):
-        predicted_subset = _sample_posterior(
-            all_trees=all_trees,
-            X=X,
-            rng=rng,
-            size=samples,
-            excluded=subset,
-            shape=shape,
-        )
-        pearson = np.zeros(samples)
-        for j in range(samples):
-            pearson[j] = (
-                pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0]
-            ) ** 2
-        ev_mean[idx] = np.mean(pearson)
-        ev_hdi[idx] = az.hdi(pearson)
-
-    axes[1].errorbar(
-        ticks, ev_mean, np.array((ev_mean - ev_hdi[:, 0], ev_hdi[:, 1] - ev_mean)), color="C0"
-    )
-    axes[1].axhline(ev_mean[-1], ls="--", color="0.5")
-    axes[1].set_xticks(ticks)
-    axes[1].set_xticklabels(ticks + 1)
-    axes[1].set_xlabel("number of covariables")
-    axes[1].set_ylabel("R²", rotation=0, labelpad=12)
-    axes[1].set_ylim(0, 1)

-    axes[0].set_xlim(-0.5, len(var_imp) - 0.5)
-    axes[1].set_xlim(-0.5, len(var_imp) - 0.5)
+def _generate_combinations(variables, excluded):
+    """
+    Generate all possible combinations of variables.
+    """
+    all_combinations = combinations(variables, len(excluded))
+    valid_combinations = [
+        com for com in all_combinations if not any(ele in com for ele in excluded)
+    ]

-    return idxs[::-1], axes
+    return valid_combinations
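To see what `_generate_combinations` yields: it returns every combination whose size equals `len(excluded)` and that avoids the already-excluded elements. A quick demonstration of the helper added above (the call and values are illustrative):

from itertools import combinations

def _generate_combinations(variables, excluded):
    # All candidate subsets of size len(excluded), minus any that
    # contain an excluded variable.
    all_combinations = combinations(variables, len(excluded))
    return [com for com in all_combinations if not any(ele in com for ele in excluded)]

print(_generate_combinations({0, 1, 2, 3}, [0]))
# -> [(1,), (2,), (3,)]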