8
8
import numpy as np
9
9
import numpy .typing as npt
10
10
import pytensor .tensor as pt
11
+ from numba import jit
11
12
from pytensor .tensor .variable import Variable
12
13
from scipy .interpolate import griddata
13
14
from scipy .signal import savgol_filter
14
- from scipy .stats import norm , pearsonr
15
+ from scipy .stats import norm
15
16
16
17
from .tree import Tree
17
18
@@ -700,8 +701,9 @@ def plot_variable_importance( # noqa: PLR0915
700
701
method : str = "VI" ,
701
702
figsize : Optional [Tuple [float , float ]] = None ,
702
703
xlabel_angle : float = 0 ,
703
- samples : int = 100 ,
704
+ samples : int = 50 ,
704
705
random_seed : Optional [int ] = None ,
706
+ plot_kwargs : Optional [Dict [str , Any ]] = None ,
705
707
ax : Optional [plt .Axes ] = None ,
706
708
) -> Tuple [List [int ], Union [List [plt .Axes ], Any ]]:
707
709
"""
@@ -733,6 +735,14 @@ def plot_variable_importance( # noqa: PLR0915
733
735
Number of predictions used to compute correlation for subsets of variables. Defaults to 100
734
736
random_seed : Optional[int]
735
737
random_seed used to sample from the posterior. Defaults to None.
738
+ plot_kwargs : dict
739
+ Additional keyword arguments for the plot. Defaults to None.
740
+ Valid keys are:
741
+ - color_r2: matplotlib valid color for error bars
742
+ - marker_r2: matplotlib valid marker for the mean R squared
743
+ - marker_fc_r2: matplotlib valid marker face color for the mean R squared
744
+ - ls_ref: matplotlib valid linestyle for the reference line
745
+ - color_ref: matplotlib valid color for the reference line
736
746
ax : axes
737
747
Matplotlib axes.
738
748
@@ -745,6 +755,9 @@ def plot_variable_importance( # noqa: PLR0915
745
755
746
756
all_trees = bartrv .owner .op .all_trees
747
757
758
+ if plot_kwargs is None :
759
+ plot_kwargs = {}
760
+
748
761
if bartrv .ndim == 1 : # type: ignore
749
762
shape = 1
750
763
else :
@@ -773,6 +786,10 @@ def plot_variable_importance( # noqa: PLR0915
773
786
all_trees , X = X , rng = rng , size = samples , excluded = None , shape = shape
774
787
)
775
788
789
+ r_2_ref = np .array (
790
+ [pearsonr2 (predicted_all [j ], predicted_all [j + 1 ]) for j in range (samples - 1 )]
791
+ )
792
+
776
793
if method == "VI" :
777
794
idxs = np .argsort (
778
795
idata ["sample_stats" ]["variable_inclusion" ].mean (("chain" , "draw" )).values
@@ -794,10 +811,7 @@ def plot_variable_importance( # noqa: PLR0915
794
811
shape = shape ,
795
812
)
796
813
r_2 = np .array (
797
- [
798
- pearsonr (predicted_all [j ].flatten (), predicted_subset [j ].flatten ())[0 ] ** 2
799
- for j in range (samples )
800
- ]
814
+ [pearsonr2 (predicted_all [j ], predicted_subset [j ]) for j in range (samples )]
801
815
)
802
816
r2_mean [idx ] = np .mean (r_2 )
803
817
r2_hdi [idx ] = az .hdi (r_2 )
@@ -833,10 +847,7 @@ def plot_variable_importance( # noqa: PLR0915
833
847
# Calculate Pearson correlation for each sample and find the mean
834
848
r_2 = np .zeros (samples )
835
849
for j in range (samples ):
836
- r_2 [j ] = (
837
- (pearsonr (predicted_all [j ].flatten (), predicted_subset [j ].flatten ())[0 ])
838
- ** 2
839
- )
850
+ r_2 [j ] = pearsonr2 (predicted_all [j ], predicted_subset [j ])
840
851
mean_r_2 = np .mean (r_2 , dtype = float )
841
852
# Identify the least important combination of variables
842
853
# based on the maximum mean squared Pearson correlation
@@ -872,9 +883,21 @@ def plot_variable_importance( # noqa: PLR0915
872
883
ticks ,
873
884
r2_mean ,
874
885
np .array ((r2_yerr_min , r2_yerr_max )),
875
- color = "C0" ,
886
+ color = plot_kwargs .get ("color_r2" , "k" ),
887
+ fmt = plot_kwargs .get ("marker_r2" , "o" ),
888
+ mfc = plot_kwargs .get ("marker_fc_r2" , "white" ),
889
+ )
890
+ ax .axhline (
891
+ np .mean (r_2_ref ),
892
+ ls = plot_kwargs .get ("ls_ref" , "--" ),
893
+ color = plot_kwargs .get ("color_ref" , "grey" ),
894
+ )
895
+ ax .fill_between (
896
+ [- 0.5 , n_vars - 0.5 ],
897
+ * az .hdi (r_2_ref ),
898
+ alpha = 0.1 ,
899
+ color = plot_kwargs .get ("color_ref" , "grey" ),
876
900
)
877
- ax .axhline (r2_mean [- 1 ], ls = "--" , color = "0.5" )
878
901
ax .set_xticks (ticks , new_labels , rotation = xlabel_angle )
879
902
ax .set_ylabel ("R²" , rotation = 0 , labelpad = 12 )
880
903
ax .set_ylim (0 , 1 )
@@ -890,3 +913,13 @@ def generate_sequences(n_vars, i_var, include):
890
913
else :
891
914
sequences = [()]
892
915
return sequences
916
+
917
+
918
+ @jit (nopython = True )
919
+ def pearsonr2 (A , B ):
920
+ """Compute the squared Pearson correlation coefficient"""
921
+ A = A .flatten ()
922
+ B = B .flatten ()
923
+ am = A - np .mean (A )
924
+ bm = B - np .mean (B )
925
+ return (am @ bm ) ** 2 / (np .sum (am ** 2 ) * np .sum (bm ** 2 ))
0 commit comments