@@ -824,10 +824,14 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
824
824
else :
825
825
shape = bartrv .eval ().shape [0 ]
826
826
827
+ n_vars = X .shape [1 ]
828
+
827
829
if hasattr (X , "columns" ) and hasattr (X , "to_numpy" ):
830
+ labels = X .columns
828
831
X = X .to_numpy ()
832
+ else :
833
+ labels = np .arange (n_vars ).astype (str )
829
834
830
- n_vars = X .shape [1 ]
831
835
r2_mean = np .zeros (n_vars )
832
836
r2_hdi = np .zeros ((n_vars , 2 ))
833
837
preds = np .zeros ((n_vars , samples , bartrv .eval ().shape [0 ]))
@@ -947,6 +951,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
947
951
948
952
vi_results = {
949
953
"indices" : indices ,
954
+ "labels" : labels [indices ],
950
955
"r2_mean" : r2_mean ,
951
956
"r2_hdi" : r2_hdi ,
952
957
"preds" : preds ,
@@ -957,7 +962,6 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
957
962
958
963
def plot_variable_importance (
959
964
vi_results : dict ,
960
- X : npt .NDArray [np .float64 ],
961
965
labels = None ,
962
966
figsize = None ,
963
967
plot_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -1008,19 +1012,13 @@ def plot_variable_importance(
1008
1012
if figsize is None :
1009
1013
figsize = (8 , 3 )
1010
1014
1011
- if hasattr (X , "columns" ) and hasattr (X , "to_numpy" ):
1012
- labels = X .columns
1013
- X = X .to_numpy ()
1014
-
1015
1015
if ax is None :
1016
1016
_ , ax = plt .subplots (1 , 1 , figsize = figsize )
1017
1017
1018
1018
if labels is None :
1019
- labels = np .arange (n_vars ).astype (str )
1020
- else :
1021
- labels = np .asarray (labels )
1019
+ labels = vi_results ["labels" ]
1022
1020
1023
- new_labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels [ indices ] )]
1021
+ labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )]
1024
1022
1025
1023
r_2_ref = np .array ([pearsonr2 (preds_all [j ], preds_all [j + 1 ]) for j in range (samples - 1 )])
1026
1024
@@ -1048,7 +1046,7 @@ def plot_variable_importance(
1048
1046
)
1049
1047
ax .set_xticks (
1050
1048
ticks ,
1051
- new_labels ,
1049
+ labels ,
1052
1050
rotation = plot_kwargs .get ("rotation" , 0 ),
1053
1051
)
1054
1052
ax .set_ylabel ("R²" , rotation = 0 , labelpad = 12 )
@@ -1060,9 +1058,9 @@ def plot_variable_importance(
1060
1058
1061
1059
def plot_scatter_submodels (
1062
1060
vi_results : dict ,
1063
- X : npt .NDArray [np .float64 ],
1064
1061
func : Optional [Callable ] = None ,
1065
1062
grid : str = "long" ,
1063
+ labels = None ,
1066
1064
figsize : Optional [Tuple [float , float ]] = None ,
1067
1065
plot_kwargs : Optional [Dict [str , Any ]] = None ,
1068
1066
axes : Optional [plt .Axes ] = None ,
@@ -1074,14 +1072,14 @@ def plot_scatter_submodels(
1074
1072
----------
1075
1073
vi_results: Dictionary
1076
1074
Dictionary computed with `compute_variable_importance`
1077
- X : npt.NDArray[np.float64]
1078
- The covariate matrix.
1079
1075
func : Optional[Callable], by default None.
1080
1076
Arbitrary function to apply to the predictions. Defaults to the identity function.
1081
1077
grid : str or tuple
1082
1078
How to arrange the subplots. Defaults to "long", one subplot below the other.
1083
1079
Other options are "wide", one subplot next to each other or a tuple indicating the number
1084
1080
of rows and columns.
1081
+ labels : Optional[List[str]]
1082
+ List of the names of the covariates.
1085
1083
plot_kwargs : dict
1086
1084
Additional keyword arguments for the plot. Defaults to None.
1087
1085
Valid keys are:
@@ -1097,23 +1095,17 @@ def plot_scatter_submodels(
1097
1095
indices = vi_results ["indices" ]
1098
1096
preds = vi_results ["preds" ]
1099
1097
preds_all = vi_results ["preds_all" ]
1100
- n_vars = len (indices )
1101
1098
1102
1099
if axes is None :
1103
1100
_ , axes = _get_axes (grid , len (indices ), True , True , figsize )
1104
1101
1105
1102
if plot_kwargs is None :
1106
1103
plot_kwargs = {}
1107
1104
1108
- if hasattr (X , "columns" ) and hasattr (X , "to_numpy" ):
1109
- labels = X .columns
1110
-
1111
1105
if labels is None :
1112
- labels = np .arange (n_vars ).astype (str )
1113
- else :
1114
- labels = np .asarray (labels )
1106
+ labels = vi_results ["labels" ]
1115
1107
1116
- new_labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels [ indices ] )]
1108
+ labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )]
1117
1109
1118
1110
if func is not None :
1119
1111
preds = func (preds )
@@ -1122,7 +1114,7 @@ def plot_scatter_submodels(
1122
1114
min_ = min (np .min (preds ), np .min (preds_all ))
1123
1115
max_ = max (np .max (preds ), np .max (preds_all ))
1124
1116
1125
- for pred , x_label , ax in zip (preds , new_labels , axes .ravel ()):
1117
+ for pred , x_label , ax in zip (preds , labels , axes .ravel ()):
1126
1118
ax .plot (
1127
1119
pred ,
1128
1120
preds_all ,
0 commit comments