Skip to content

Commit 38a11f3

Browse files
committed
fix tests
1 parent 90da2cd commit 38a11f3

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

pymc_bart/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
plot_pdp,
2525
plot_scatter_submodels,
2626
plot_variable_importance,
27+
plot_variable_inclusion,
2728
)
2829

2930
__all__ = [
@@ -32,13 +33,14 @@
3233
"ContinuousSplitRule",
3334
"OneHotSplitRule",
3435
"SubsetSplitRule",
36+
"compute_variable_importance",
3537
"plot_convergence",
3638
"plot_dependence",
3739
"plot_ice",
3840
"plot_pdp",
39-
"plot_variable_importance",
40-
"compute_variable_importance",
4141
"plot_scatter_submodels",
42+
"plot_variable_importance",
43+
"plot_variable_inclusion",
4244
]
4345
__version__ = "0.7.1"
4446

pymc_bart/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def _create_figure_axes(
492492
n_plots = len(var_idx) * shape
493493

494494
if ax is None:
495-
axes = _get_axes(grid, n_plots, False, sharey, figsize)
495+
fig, axes = _get_axes(grid, n_plots, False, sharey, figsize)
496496

497497
elif isinstance(ax, np.ndarray):
498498
axes = ax
@@ -528,7 +528,7 @@ def _get_axes(grid, n_plots, sharex, sharey, figsize):
528528
for i in range(n_plots, len(axes)):
529529
fig.delaxes(axes[i])
530530
axes = axes[:n_plots]
531-
return axes
531+
return fig, axes
532532

533533

534534
def _prepare_plot_data(
@@ -958,7 +958,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
958958
def plot_variable_importance(
959959
vi_results: dict,
960960
X: npt.NDArray[np.float64],
961-
labels: Optional[List[str]] = None,
961+
labels=None,
962962
figsize=None,
963963
plot_kwargs: Optional[Dict[str, Any]] = None,
964964
ax: Optional[plt.Axes] = None,
@@ -1016,7 +1016,9 @@ def plot_variable_importance(
10161016
_, ax = plt.subplots(1, 1, figsize=figsize)
10171017

10181018
if labels is None:
1019-
labels = list(np.arange(n_vars).astype(str))
1019+
labels = np.arange(n_vars).astype(str)
1020+
else:
1021+
labels = np.asarray(labels)
10201022

10211023
new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
10221024

@@ -1062,7 +1064,7 @@ def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None):
10621064
preds_all = vi_results["preds_all"]
10631065

10641066
if axes is None:
1065-
axes = _get_axes(grid, len(indices), False, True, None)
1067+
_, axes = _get_axes(grid, len(indices), False, True, None)
10661068

10671069
func = None
10681070
if func is not None:

0 commit comments

Comments
 (0)