@@ -705,7 +705,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
705
705
706
706
Parameters
707
707
----------
708
- idata: InferenceData
708
+ idata : InferenceData
709
709
InferenceData containing a collection of BART_trees in sample_stats group
710
710
X : npt.NDArray[np.float64]
711
711
The covariate matrix.
@@ -784,7 +784,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
784
784
785
785
Parameters
786
786
----------
787
- idata: InferenceData
787
+ idata : InferenceData
788
788
InferenceData containing a collection of BART_trees in sample_stats group
789
789
bartrv : BART Random Variable
790
790
BART variable once the model that includes it has been fitted.
@@ -949,8 +949,10 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
949
949
950
950
indices = least_important_vars [::- 1 ]
951
951
952
+ labels = np .array (["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )])
953
+
952
954
vi_results = {
953
- "indices" : indices ,
955
+ "indices" : np . asarray ( indices ) ,
954
956
"labels" : labels [indices ],
955
957
"r2_mean" : r2_mean ,
956
958
"r2_hdi" : r2_hdi ,
@@ -962,8 +964,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
962
964
963
965
def plot_variable_importance (
964
966
vi_results : dict ,
965
- labels = None ,
966
- figsize = None ,
967
+ submodels : Optional [Union [list [int ], np .ndarray , tuple [int , ...]]] = None ,
968
+ labels : Optional [list [str ]] = None ,
969
+ figsize : Optional [tuple [float , float ]] = None ,
967
970
plot_kwargs : Optional [dict [str , Any ]] = None ,
968
971
ax : Optional [plt .Axes ] = None ,
969
972
):
@@ -974,8 +977,11 @@ def plot_variable_importance(
974
977
----------
975
978
vi_results: Dictionary
976
979
Dictionary computed with `compute_variable_importance`
977
- X : npt.NDArray[np.float64]
978
- The covariate matrix.
980
+ submodels : Optional[Union[list[int], np.ndarray, tuple[int, ...]]]
981
+ List of the indices of the submodels to plot. Defaults to None, in which case all submodels are plotted.
982
+ The indices correspond to order computed by `compute_variable_importance`.
983
+ For example `submodels=[0,1]` will plot the two most important variables.
984
+ `submodels=[1,0]` is equivalent as values are sorted before use.
979
985
labels : Optional[list[str]]
980
986
List of the names of the covariates. If X is a DataFrame the names of the covariables will
981
987
be taken from it and this argument will be ignored.
@@ -995,11 +1001,15 @@ def plot_variable_importance(
995
1001
-------
996
1002
axes: matplotlib axes
997
1003
"""
1004
+ if submodels is None :
1005
+ submodels = np .sort (vi_results ["indices" ])
1006
+ else :
1007
+ submodels = np .sort (submodels )
998
1008
999
- indices = vi_results ["indices" ]
1000
- r2_mean = vi_results ["r2_mean" ]
1001
- r2_hdi = vi_results ["r2_hdi" ]
1002
- preds = vi_results ["preds" ]
1009
+ indices = vi_results ["indices" ][ submodels ]
1010
+ r2_mean = vi_results ["r2_mean" ][ submodels ]
1011
+ r2_hdi = vi_results ["r2_hdi" ][ submodels ]
1012
+ preds = vi_results ["preds" ][ submodels ]
1003
1013
preds_all = vi_results ["preds_all" ]
1004
1014
samples = preds .shape [1 ]
1005
1015
@@ -1016,9 +1026,7 @@ def plot_variable_importance(
1016
1026
_ , ax = plt .subplots (1 , 1 , figsize = figsize )
1017
1027
1018
1028
if labels is None :
1019
- labels = vi_results ["labels" ]
1020
-
1021
- labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )]
1029
+ labels = vi_results ["labels" ][submodels ]
1022
1030
1023
1031
r_2_ref = np .array ([pearsonr2 (preds_all [j ], preds_all [j + 1 ]) for j in range (samples - 1 )])
1024
1032
@@ -1059,21 +1067,27 @@ def plot_variable_importance(
1059
1067
def plot_scatter_submodels (
1060
1068
vi_results : dict ,
1061
1069
func : Optional [Callable ] = None ,
1070
+ submodels : Optional [Union [list [int ], np .ndarray ]] = None ,
1062
1071
grid : str = "long" ,
1063
- labels = None ,
1072
+ labels : Optional [ list [ str ]] = None ,
1064
1073
figsize : Optional [tuple [float , float ]] = None ,
1065
1074
plot_kwargs : Optional [dict [str , Any ]] = None ,
1066
- axes : Optional [plt .Axes ] = None ,
1067
- ):
1075
+ ax : Optional [plt .Axes ] = None ,
1076
+ ) -> list [ plt . Axes ] :
1068
1077
"""
1069
1078
Plot submodel's predictions against reference-model's predictions.
1070
1079
1071
1080
Parameters
1072
1081
----------
1073
- vi_results: Dictionary
1082
+ vi_results : Dictionary
1074
1083
Dictionary computed with `compute_variable_importance`
1075
1084
func : Optional[Callable], by default None.
1076
1085
Arbitrary function to apply to the predictions. Defaults to the identity function.
1086
+ submodels : Optional[Union[list[int], np.ndarray]]
1087
+ List of the indices of the submodels to plot. Defaults to None, in which case all submodels are plotted.
1088
+ The indices correspond to order computed by `compute_variable_importance`.
1089
+ For example `submodels=[0,1]` will plot the two most important variables.
1090
+ `submodels=[1,0]` is equivalent as values are sorted before use.
1077
1091
grid : str or tuple
1078
1092
How to arrange the subplots. Defaults to "long", one subplot below the other.
1079
1093
Other options are "wide", one subplot next to each other or a tuple indicating the number
@@ -1092,20 +1106,23 @@ def plot_scatter_submodels(
1092
1106
-------
1093
1107
axes: matplotlib axes
1094
1108
"""
1095
- indices = vi_results ["indices" ]
1096
- preds = vi_results ["preds" ]
1109
+ if submodels is None :
1110
+ submodels = np .sort (vi_results ["indices" ])
1111
+ else :
1112
+ submodels = np .sort (submodels )
1113
+
1114
+ indices = vi_results ["indices" ][submodels ]
1115
+ preds = vi_results ["preds" ][submodels ]
1097
1116
preds_all = vi_results ["preds_all" ]
1098
1117
1099
- if axes is None :
1100
- _ , axes = _get_axes (grid , len (indices ), True , True , figsize )
1118
+ if ax is None :
1119
+ _ , ax = _get_axes (grid , len (indices ), True , True , figsize )
1101
1120
1102
1121
if plot_kwargs is None :
1103
1122
plot_kwargs = {}
1104
1123
1105
1124
if labels is None :
1106
- labels = vi_results ["labels" ]
1107
-
1108
- labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )]
1125
+ labels = vi_results ["labels" ][submodels ]
1109
1126
1110
1127
if func is not None :
1111
1128
preds = func (preds )
@@ -1114,22 +1131,23 @@ def plot_scatter_submodels(
1114
1131
min_ = min (np .min (preds ), np .min (preds_all ))
1115
1132
max_ = max (np .max (preds ), np .max (preds_all ))
1116
1133
1117
- for pred , x_label , ax in zip (preds , labels , axes .ravel ()):
1118
- ax .plot (
1134
+ for pred , x_label , axi in zip (preds , labels , ax .ravel ()):
1135
+ axi .plot (
1119
1136
pred ,
1120
1137
preds_all ,
1121
1138
marker = plot_kwargs .get ("marker_scatter" , "." ),
1122
1139
ls = "" ,
1123
1140
color = plot_kwargs .get ("color_scatter" , "C0" ),
1124
1141
alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1125
1142
)
1126
- ax .set_xlabel (x_label )
1127
- ax .axline (
1143
+ axi .set_xlabel (x_label )
1144
+ axi .axline (
1128
1145
[min_ , min_ ],
1129
1146
[max_ , max_ ],
1130
1147
color = plot_kwargs .get ("color_ref" , "0.5" ),
1131
1148
ls = plot_kwargs .get ("ls_ref" , "--" ),
1132
1149
)
1150
+ return ax
1133
1151
1134
1152
1135
1153
def generate_sequences (n_vars , i_var , include ):
0 commit comments