@@ -693,6 +693,50 @@ def _smooth_mean(
693
693
return x_data , y_data
694
694
695
695
696
+ def get_variable_inclusion (idata , X , labels = None , to_kulprit = False ):
697
+ """
698
+ Get the normalized variable inclusion from BART model.
699
+
700
+ Parameters
701
+ ----------
702
+ idata : InferenceData
703
+ InferenceData containing a collection of BART_trees in sample_stats group
704
+ X : npt.NDArray
705
+ The covariate matrix.
706
+ labels : Optional[list[str]]
707
+ List of the names of the covariates. If X is a DataFrame the names of the covariables will
708
+ be taken from it and this argument will be ignored.
709
+ to_kulprit : bool
710
+ If True, the function will return a list of list with the variables names.
711
+ This list can be passed as a path to Kulprit's project method. Defaults to False.
712
+ Returns
713
+ -------
714
+ VI_norm : npt.NDArray
715
+ Normalized variable inclusion.
716
+ labels : list[str]
717
+ List of the names of the covariates.
718
+ """
719
+ VIs = idata ["sample_stats" ]["variable_inclusion" ].mean (("chain" , "draw" )).values
720
+ VI_norm = VIs / VIs .sum ()
721
+ idxs = np .argsort (VI_norm )
722
+
723
+ indices = idxs [::- 1 ]
724
+ n_vars = len (indices )
725
+
726
+ if hasattr (X , "columns" ) and hasattr (X , "to_numpy" ):
727
+ labels = X .columns
728
+
729
+ if labels is None :
730
+ labels = np .arange (n_vars ).astype (str )
731
+
732
+ label_list = labels .to_list ()
733
+
734
+ if to_kulprit :
735
+ return [label_list [:idx ] for idx in range (n_vars )]
736
+ else :
737
+ return VI_norm [indices ], label_list
738
+
739
+
696
740
def plot_variable_inclusion (idata , X , labels = None , figsize = None , plot_kwargs = None , ax = None ):
697
741
"""
698
742
Plot normalized variable inclusion from BART model.
@@ -720,26 +764,15 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
720
764
721
765
Returns
722
766
-------
723
- idxs: indexes of the covariates from higher to lower relative importance
724
767
axes: matplotlib axes
725
768
"""
726
769
if plot_kwargs is None :
727
770
plot_kwargs = {}
728
771
729
- VIs = idata ["sample_stats" ]["variable_inclusion" ].mean (("chain" , "draw" )).values
730
- VIs = VIs / VIs .sum ()
731
- idxs = np .argsort (VIs )
732
-
733
- indices = idxs [::- 1 ]
734
- n_vars = len (indices )
735
-
736
- if hasattr (X , "columns" ) and hasattr (X , "to_numpy" ):
737
- labels = X .columns
772
+ VI_norm , labels = get_variable_inclusion (idata , X , labels )
773
+ n_vars = len (labels )
738
774
739
- if labels is None :
740
- labels = np .arange (n_vars ).astype (str )
741
-
742
- new_labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels [indices ])]
775
+ new_labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )]
743
776
744
777
ticks = np .arange (n_vars , dtype = int )
745
778
@@ -749,19 +782,18 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
749
782
if ax is None :
750
783
_ , ax = plt .subplots (1 , 1 , figsize = figsize )
751
784
785
+ ax .axhline (1 / n_vars , color = "0.5" , linestyle = "--" )
752
786
ax .plot (
753
- VIs [ indices ] ,
787
+ VI_norm ,
754
788
color = plot_kwargs .get ("color" , "k" ),
755
789
marker = plot_kwargs .get ("marker" , "o" ),
756
790
ls = plot_kwargs .get ("ls" , "-" ),
757
791
)
758
792
759
793
ax .set_xticks (ticks , new_labels , rotation = plot_kwargs .get ("rotation" , 0 ))
760
-
761
- ax .axhline (1 / n_vars , color = "0.5" , linestyle = "--" )
762
794
ax .set_ylim (0 , 1 )
763
795
764
- return idxs , ax
796
+ return ax
765
797
766
798
767
799
def compute_variable_importance ( # noqa: PLR0915 PLR0912
0 commit comments