Skip to content

Commit 38bf6b0

Browse files
committed
remove X argument from plots
1 parent d72c963 commit 38bf6b0

File tree

2 files changed

+17
-25
lines changed

2 files changed

+17
-25
lines changed

pymc_bart/utils.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -824,10 +824,14 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
824824
else:
825825
shape = bartrv.eval().shape[0]
826826

827+
n_vars = X.shape[1]
828+
827829
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
830+
labels = X.columns
828831
X = X.to_numpy()
832+
else:
833+
labels = np.arange(n_vars).astype(str)
829834

830-
n_vars = X.shape[1]
831835
r2_mean = np.zeros(n_vars)
832836
r2_hdi = np.zeros((n_vars, 2))
833837
preds = np.zeros((n_vars, samples, bartrv.eval().shape[0]))
@@ -947,6 +951,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
947951

948952
vi_results = {
949953
"indices": indices,
954+
"labels": labels[indices],
950955
"r2_mean": r2_mean,
951956
"r2_hdi": r2_hdi,
952957
"preds": preds,
@@ -957,7 +962,6 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
957962

958963
def plot_variable_importance(
959964
vi_results: dict,
960-
X: npt.NDArray[np.float64],
961965
labels=None,
962966
figsize=None,
963967
plot_kwargs: Optional[Dict[str, Any]] = None,
@@ -1008,19 +1012,13 @@ def plot_variable_importance(
10081012
if figsize is None:
10091013
figsize = (8, 3)
10101014

1011-
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
1012-
labels = X.columns
1013-
X = X.to_numpy()
1014-
10151015
if ax is None:
10161016
_, ax = plt.subplots(1, 1, figsize=figsize)
10171017

10181018
if labels is None:
1019-
labels = np.arange(n_vars).astype(str)
1020-
else:
1021-
labels = np.asarray(labels)
1019+
labels = vi_results["labels"]
10221020

1023-
new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
1021+
labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
10241022

10251023
r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)])
10261024

@@ -1048,7 +1046,7 @@ def plot_variable_importance(
10481046
)
10491047
ax.set_xticks(
10501048
ticks,
1051-
new_labels,
1049+
labels,
10521050
rotation=plot_kwargs.get("rotation", 0),
10531051
)
10541052
ax.set_ylabel("R²", rotation=0, labelpad=12)
@@ -1060,9 +1058,9 @@ def plot_variable_importance(
10601058

10611059
def plot_scatter_submodels(
10621060
vi_results: dict,
1063-
X: npt.NDArray[np.float64],
10641061
func: Optional[Callable] = None,
10651062
grid: str = "long",
1063+
labels=None,
10661064
figsize: Optional[Tuple[float, float]] = None,
10671065
plot_kwargs: Optional[Dict[str, Any]] = None,
10681066
axes: Optional[plt.Axes] = None,
@@ -1074,14 +1072,14 @@ def plot_scatter_submodels(
10741072
----------
10751073
vi_results: Dictionary
10761074
Dictionary computed with `compute_variable_importance`
1077-
X : npt.NDArray[np.float64]
1078-
The covariate matrix.
10791075
func : Optional[Callable], by default None.
10801076
Arbitrary function to apply to the predictions. Defaults to the identity function.
10811077
grid : str or tuple
10821078
How to arrange the subplots. Defaults to "long", one subplot below the other.
10831079
Other options are "wide", one subplot next to each other or a tuple indicating the number
10841080
of rows and columns.
1081+
labels : Optional[List[str]]
1082+
List of the names of the covariates.
10851083
plot_kwargs : dict
10861084
Additional keyword arguments for the plot. Defaults to None.
10871085
Valid keys are:
@@ -1097,23 +1095,17 @@ def plot_scatter_submodels(
10971095
indices = vi_results["indices"]
10981096
preds = vi_results["preds"]
10991097
preds_all = vi_results["preds_all"]
1100-
n_vars = len(indices)
11011098

11021099
if axes is None:
11031100
_, axes = _get_axes(grid, len(indices), True, True, figsize)
11041101

11051102
if plot_kwargs is None:
11061103
plot_kwargs = {}
11071104

1108-
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
1109-
labels = X.columns
1110-
11111105
if labels is None:
1112-
labels = np.arange(n_vars).astype(str)
1113-
else:
1114-
labels = np.asarray(labels)
1106+
labels = vi_results["labels"]
11151107

1116-
new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
1108+
labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
11171109

11181110
if func is not None:
11191111
preds = func(preds)
@@ -1122,7 +1114,7 @@ def plot_scatter_submodels(
11221114
min_ = min(np.min(preds), np.min(preds_all))
11231115
max_ = max(np.max(preds), np.max(preds_all))
11241116

1125-
for pred, x_label, ax in zip(preds, new_labels, axes.ravel()):
1117+
for pred, x_label, ax in zip(preds, labels, axes.ravel()):
11261118
ax.plot(
11271119
pred,
11281120
preds_all,

tests/test_bart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,8 @@ def test_vi(self, kwargs):
193193
vi_results = pmb.compute_variable_importance(
194194
self.idata, bartrv=self.mu, X=self.X, samples=samples
195195
)
196-
pmb.plot_variable_importance(vi_results, X=self.X, **kwargs)
197-
pmb.plot_scatter_submodels(vi_results)
196+
pmb.plot_variable_importance(vi_results, **kwargs)
197+
pmb.plot_scatter_submodels(vi_results, **kwargs)
198198

199199
def test_pdp_pandas_labels(self):
200200
pd = pytest.importorskip("pandas")

0 commit comments

Comments
 (0)