@@ -492,7 +492,7 @@ def _create_figure_axes(
492
492
n_plots = len (var_idx ) * shape
493
493
494
494
if ax is None :
495
- axes = _get_axes (grid , n_plots , False , sharey , figsize )
495
+ fig , axes = _get_axes (grid , n_plots , False , sharey , figsize )
496
496
497
497
elif isinstance (ax , np .ndarray ):
498
498
axes = ax
@@ -528,7 +528,7 @@ def _get_axes(grid, n_plots, sharex, sharey, figsize):
528
528
for i in range (n_plots , len (axes )):
529
529
fig .delaxes (axes [i ])
530
530
axes = axes [:n_plots ]
531
- return axes
531
+ return fig , axes
532
532
533
533
534
534
def _prepare_plot_data (
@@ -958,7 +958,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
958
958
def plot_variable_importance (
959
959
vi_results : dict ,
960
960
X : npt .NDArray [np .float64 ],
961
- labels : Optional [ List [ str ]] = None ,
961
+ labels = None ,
962
962
figsize = None ,
963
963
plot_kwargs : Optional [Dict [str , Any ]] = None ,
964
964
ax : Optional [plt .Axes ] = None ,
@@ -1016,7 +1016,9 @@ def plot_variable_importance(
1016
1016
_ , ax = plt .subplots (1 , 1 , figsize = figsize )
1017
1017
1018
1018
if labels is None :
1019
- labels = list (np .arange (n_vars ).astype (str ))
1019
+ labels = np .arange (n_vars ).astype (str )
1020
+ else :
1021
+ labels = np .asarray (labels )
1020
1022
1021
1023
new_labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels [indices ])]
1022
1024
@@ -1062,7 +1064,7 @@ def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None):
1062
1064
preds_all = vi_results ["preds_all" ]
1063
1065
1064
1066
if axes is None :
1065
- axes = _get_axes (grid , len (indices ), False , True , None )
1067
+ _ , axes = _get_axes (grid , len (indices ), False , True , None )
1066
1068
1067
1069
func = None
1068
1070
if func is not None :
0 commit comments