@@ -705,7 +705,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
705705
706706 Parameters
707707 ----------
708- idata: InferenceData
708+ idata : InferenceData
709709 InferenceData containing a collection of BART_trees in sample_stats group
710710 X : npt.NDArray[np.float64]
711711 The covariate matrix.
@@ -784,7 +784,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
784784
785785 Parameters
786786 ----------
787- idata: InferenceData
787+ idata : InferenceData
788788 InferenceData containing a collection of BART_trees in sample_stats group
789789 bartrv : BART Random Variable
790790 BART variable once the model that include it has been fitted.
@@ -949,8 +949,10 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
949949
950950 indices = least_important_vars [::- 1 ]
951951
952+ labels = np .array (["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )])
953+
952954 vi_results = {
953- "indices" : indices ,
955+ "indices" : np . asarray ( indices ) ,
954956 "labels" : labels [indices ],
955957 "r2_mean" : r2_mean ,
956958 "r2_hdi" : r2_hdi ,
@@ -962,8 +964,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
962964
963965def plot_variable_importance (
964966 vi_results : dict ,
965- labels = None ,
966- figsize = None ,
967+ submodels : Optional [Union [list [int ], np .ndarray , tuple [int , ...]]] = None ,
968+ labels : Optional [list [str ]] = None ,
969+ figsize : Optional [tuple [float , float ]] = None ,
967970 plot_kwargs : Optional [dict [str , Any ]] = None ,
968971 ax : Optional [plt .Axes ] = None ,
969972):
@@ -974,8 +977,11 @@ def plot_variable_importance(
974977 ----------
975978 vi_results: Dictionary
976979 Dictionary computed with `compute_variable_importance`
977- X : npt.NDArray[np.float64]
978- The covariate matrix.
980+ submodels : Optional[Union[list[int], np.ndarray]]
981+ List of the indices of the submodels to plot. Defaults to None, all variables are ploted.
982+ The indices correspond to order computed by `compute_variable_importance`.
983+ For example `submodels=[0,1]` will plot the two most important variables.
984+ `submodels=[1,0]` is equivalent as values are sorted before use.
979985 labels : Optional[list[str]]
980986 List of the names of the covariates. If X is a DataFrame the names of the covariables will
981987 be taken from it and this argument will be ignored.
@@ -995,11 +1001,15 @@ def plot_variable_importance(
9951001 -------
9961002 axes: matplotlib axes
9971003 """
1004+ if submodels is None :
1005+ submodels = np .sort (vi_results ["indices" ])
1006+ else :
1007+ submodels = np .sort (submodels )
9981008
999- indices = vi_results ["indices" ]
1000- r2_mean = vi_results ["r2_mean" ]
1001- r2_hdi = vi_results ["r2_hdi" ]
1002- preds = vi_results ["preds" ]
1009+ indices = vi_results ["indices" ][ submodels ]
1010+ r2_mean = vi_results ["r2_mean" ][ submodels ]
1011+ r2_hdi = vi_results ["r2_hdi" ][ submodels ]
1012+ preds = vi_results ["preds" ][ submodels ]
10031013 preds_all = vi_results ["preds_all" ]
10041014 samples = preds .shape [1 ]
10051015
@@ -1016,9 +1026,7 @@ def plot_variable_importance(
10161026 _ , ax = plt .subplots (1 , 1 , figsize = figsize )
10171027
10181028 if labels is None :
1019- labels = vi_results ["labels" ]
1020-
1021- labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )]
1029+ labels = vi_results ["labels" ][submodels ]
10221030
10231031 r_2_ref = np .array ([pearsonr2 (preds_all [j ], preds_all [j + 1 ]) for j in range (samples - 1 )])
10241032
@@ -1059,21 +1067,27 @@ def plot_variable_importance(
10591067def plot_scatter_submodels (
10601068 vi_results : dict ,
10611069 func : Optional [Callable ] = None ,
1070+ submodels : Optional [Union [list [int ], np .ndarray ]] = None ,
10621071 grid : str = "long" ,
1063- labels = None ,
1072+ labels : Optional [ list [ str ]] = None ,
10641073 figsize : Optional [tuple [float , float ]] = None ,
10651074 plot_kwargs : Optional [dict [str , Any ]] = None ,
1066- axes : Optional [plt .Axes ] = None ,
1067- ):
1075+ ax : Optional [plt .Axes ] = None ,
1076+ ) -> list [ plt . Axes ] :
10681077 """
10691078 Plot submodel's predictions against reference-model's predictions.
10701079
10711080 Parameters
10721081 ----------
1073- vi_results: Dictionary
1082+ vi_results : Dictionary
10741083 Dictionary computed with `compute_variable_importance`
10751084 func : Optional[Callable], by default None.
10761085 Arbitrary function to apply to the predictions. Defaults to the identity function.
1086+ submodels : Optional[Union[list[int], np.ndarray]]
1087+ List of the indices of the submodels to plot. Defaults to None, all variables are ploted.
1088+ The indices correspond to order computed by `compute_variable_importance`.
1089+ For example `submodels=[0,1]` will plot the two most important variables.
1090+ `submodels=[1,0]` is equivalent as values are sorted before use.
10771091 grid : str or tuple
10781092 How to arrange the subplots. Defaults to "long", one subplot below the other.
10791093 Other options are "wide", one subplot next to each other or a tuple indicating the number
@@ -1092,20 +1106,23 @@ def plot_scatter_submodels(
10921106 -------
10931107 axes: matplotlib axes
10941108 """
1095- indices = vi_results ["indices" ]
1096- preds = vi_results ["preds" ]
1109+ if submodels is None :
1110+ submodels = np .sort (vi_results ["indices" ])
1111+ else :
1112+ submodels = np .sort (submodels )
1113+
1114+ indices = vi_results ["indices" ][submodels ]
1115+ preds = vi_results ["preds" ][submodels ]
10971116 preds_all = vi_results ["preds_all" ]
10981117
1099- if axes is None :
1100- _ , axes = _get_axes (grid , len (indices ), True , True , figsize )
1118+ if ax is None :
1119+ _ , ax = _get_axes (grid , len (indices ), True , True , figsize )
11011120
11021121 if plot_kwargs is None :
11031122 plot_kwargs = {}
11041123
11051124 if labels is None :
1106- labels = vi_results ["labels" ]
1107-
1108- labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )]
1125+ labels = vi_results ["labels" ][submodels ]
11091126
11101127 if func is not None :
11111128 preds = func (preds )
@@ -1114,22 +1131,23 @@ def plot_scatter_submodels(
11141131 min_ = min (np .min (preds ), np .min (preds_all ))
11151132 max_ = max (np .max (preds ), np .max (preds_all ))
11161133
1117- for pred , x_label , ax in zip (preds , labels , axes .ravel ()):
1118- ax .plot (
1134+ for pred , x_label , axi in zip (preds , labels , ax .ravel ()):
1135+ axi .plot (
11191136 pred ,
11201137 preds_all ,
11211138 marker = plot_kwargs .get ("marker_scatter" , "." ),
11221139 ls = "" ,
11231140 color = plot_kwargs .get ("color_scatter" , "C0" ),
11241141 alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
11251142 )
1126- ax .set_xlabel (x_label )
1127- ax .axline (
1143+ axi .set_xlabel (x_label )
1144+ axi .axline (
11281145 [min_ , min_ ],
11291146 [max_ , max_ ],
11301147 color = plot_kwargs .get ("color_ref" , "0.5" ),
11311148 ls = plot_kwargs .get ("ls_ref" , "--" ),
11321149 )
1150+ return ax
11331151
11341152
11351153def generate_sequences (n_vars , i_var , include ):
0 commit comments