@@ -824,10 +824,14 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
824
824
else :
825
825
shape = bartrv .eval ().shape [0 ]
826
826
827
+ n_vars = X .shape [1 ]
828
+
827
829
if hasattr (X , "columns" ) and hasattr (X , "to_numpy" ):
830
+ labels = X .columns
828
831
X = X .to_numpy ()
832
+ else :
833
+ labels = np .arange (n_vars ).astype (str )
829
834
830
- n_vars = X .shape [1 ]
831
835
r2_mean = np .zeros (n_vars )
832
836
r2_hdi = np .zeros ((n_vars , 2 ))
833
837
preds = np .zeros ((n_vars , samples , bartrv .eval ().shape [0 ]))
@@ -947,6 +951,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
947
951
948
952
vi_results = {
949
953
"indices" : indices ,
954
+ "labels" : labels [indices ],
950
955
"r2_mean" : r2_mean ,
951
956
"r2_hdi" : r2_hdi ,
952
957
"preds" : preds ,
@@ -957,7 +962,6 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
957
962
958
963
def plot_variable_importance (
959
964
vi_results : dict ,
960
- X : npt .NDArray [np .float64 ],
961
965
labels = None ,
962
966
figsize = None ,
963
967
plot_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -1008,19 +1012,13 @@ def plot_variable_importance(
1008
1012
if figsize is None :
1009
1013
figsize = (8 , 3 )
1010
1014
1011
- if hasattr (X , "columns" ) and hasattr (X , "to_numpy" ):
1012
- labels = X .columns
1013
- X = X .to_numpy ()
1014
-
1015
1015
if ax is None :
1016
1016
_ , ax = plt .subplots (1 , 1 , figsize = figsize )
1017
1017
1018
1018
if labels is None :
1019
- labels = np .arange (n_vars ).astype (str )
1020
- else :
1021
- labels = np .asarray (labels )
1019
+ labels = vi_results ["labels" ]
1022
1020
1023
- new_labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels [ indices ] )]
1021
+ labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )]
1024
1022
1025
1023
r_2_ref = np .array ([pearsonr2 (preds_all [j ], preds_all [j + 1 ]) for j in range (samples - 1 )])
1026
1024
@@ -1048,7 +1046,7 @@ def plot_variable_importance(
1048
1046
)
1049
1047
ax .set_xticks (
1050
1048
ticks ,
1051
- new_labels ,
1049
+ labels ,
1052
1050
rotation = plot_kwargs .get ("rotation" , 0 ),
1053
1051
)
1054
1052
ax .set_ylabel ("R²" , rotation = 0 , labelpad = 12 )
@@ -1058,25 +1056,80 @@ def plot_variable_importance(
1058
1056
return ax
1059
1057
1060
1058
1061
def plot_scatter_submodels(
    vi_results: dict,
    func: Optional[Callable] = None,
    grid: str = "long",
    labels=None,
    figsize: Optional[Tuple[float, float]] = None,
    plot_kwargs: Optional[Dict[str, Any]] = None,
    axes: Optional[np.ndarray] = None,
):
    """
    Plot submodel's predictions against reference-model's predictions.

    Parameters
    ----------
    vi_results: Dictionary
        Dictionary computed with `compute_variable_importance`. The keys "indices",
        "preds" and "preds_all" are always read; "labels" is read when `labels` is None.
    func : Optional[Callable], by default None.
        Arbitrary function to apply to the predictions. Defaults to the identity function.
    grid : str or tuple
        How to arrange the subplots. Defaults to "long", one subplot below the other.
        Other options are "wide", one subplot next to each other or a tuple indicating the number
        of rows and columns. Only used when `axes` is None.
    labels : Optional[List[str]]
        List of the names of the covariates, one per submodel. Defaults to
        ``vi_results["labels"]``. Entries are converted to ``str``; every entry after the
        first is prefixed with "+ ".
    figsize : Optional[Tuple[float, float]]
        Figure size forwarded to the axes-creation helper. Only used when `axes` is None.
    plot_kwargs : dict
        Additional keyword arguments for the plot. Defaults to None.
        Valid keys are:
        - marker_scatter: matplotlib valid marker for the scatter plot. Defaults to "."
        - color_scatter: matplotlib valid color for the scatter plot. Defaults to "C0"
        - alpha_scatter: transparency of the scatter plot. Defaults to 0.1
        - color_ref: matplotlib valid color for the 45 degree line. Defaults to "0.5"
        - ls_ref: matplotlib valid linestyle for the 45 degree line. Defaults to "--"
    axes : array of axes
        Matplotlib axes, given as an array because ``axes.ravel()`` is called on it.
        When None, new axes are created from `grid` and `figsize`.

    Returns
    -------
    axes: matplotlib axes
    """
    indices = vi_results["indices"]
    preds = vi_results["preds"]
    preds_all = vi_results["preds_all"]

    if axes is None:
        # NOTE(review): the two positional booleans are presumably the sharex/sharey
        # flags of `_get_axes` -- confirm against that helper's signature.
        _, axes = _get_axes(grid, len(indices), True, True, figsize)

    if plot_kwargs is None:
        plot_kwargs = {}

    if labels is None:
        labels = vi_results["labels"]

    # Coerce to str so that non-string names (e.g. integer column names) do not make
    # the "+ " concatenation raise a TypeError.
    labels = [
        f"+ {label}" if position != 0 else str(label) for position, label in enumerate(labels)
    ]

    if func is not None:
        preds = func(preds)
        preds_all = func(preds_all)

    # Common range across submodels and reference model; anchors the reference line.
    min_ = min(np.min(preds), np.min(preds_all))
    max_ = max(np.max(preds), np.max(preds_all))

    # zip stops at the shortest of preds, labels and axes.
    for pred, x_label, ax in zip(preds, labels, axes.ravel()):
        ax.plot(
            pred,
            preds_all,
            marker=plot_kwargs.get("marker_scatter", "."),
            ls="",
            color=plot_kwargs.get("color_scatter", "C0"),
            alpha=plot_kwargs.get("alpha_scatter", 0.1),
        )
        ax.set_xlabel(x_label)
        ax.axline(
            [min_, min_],
            [max_, max_],
            color=plot_kwargs.get("color_ref", "0.5"),
            ls=plot_kwargs.get("ls_ref", "--"),
        )

    return axes
1080
1133
1081
1134
1082
1135
def generate_sequences (n_vars , i_var , include ):
0 commit comments