Skip to content

Commit a1adedf

Browse files
authored
Fix bug in plot_ice, and clean docstring of plot_ice and plot_pdp (#135)
* fix plot_pdp/ice
* fix test
* fix type hints
1 parent 1d2287e commit a1adedf

File tree

2 files changed

+18
-29
lines changed

2 files changed

+18
-29
lines changed

pymc_bart/utils.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,12 @@ def plot_ice(
155155
bartrv: Variable,
156156
X: npt.NDArray[np.float_],
157157
Y: Optional[npt.NDArray[np.float_]] = None,
158-
xs_interval: str = "quantiles",
159-
xs_values: Optional[Union[int, List[float]]] = None,
160158
var_idx: Optional[List[int]] = None,
161159
var_discrete: Optional[List[int]] = None,
162160
func: Optional[Callable] = None,
163161
centered: Optional[bool] = True,
164-
samples: int = 50,
165-
instances: int = 10,
162+
samples: int = 100,
163+
instances: int = 30,
166164
random_seed: Optional[int] = None,
167165
sharey: bool = True,
168166
smooth: bool = True,
@@ -185,16 +183,6 @@ def plot_ice(
185183
The covariate matrix.
186184
Y : Optional[npt.NDArray[np.float_]], by default None.
187185
The response vector.
188-
xs_interval : str
189-
Method used to compute the values X used to evaluate the predicted function. "linear",
190-
evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified
191-
quantiles of X. "insample", the evaluation is done at the values of X.
192-
For discrete variables these options are omitted.
193-
xs_values : Optional[Union[int, List[float]]], by default None.
194-
Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
195-
points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of
196-
quantiles to compute, which must be between 0 and 1 inclusive.
197-
Ignored when ``xs_interval="insample"``.
198186
var_idx : Optional[List[int]], by default None.
199187
List of the indices of the covariate for which to compute the pdp or ice.
200188
var_discrete : Optional[List[int]], by default None.
@@ -205,22 +193,20 @@ def plot_ice(
205193
If True the result is centered around the partial response evaluated at the lowest value in
206194
``xs_interval``. Defaults to True.
207195
samples : int
208-
Number of posterior samples used in the predictions. Defaults to 50
196+
Number of posterior samples used in the predictions. Defaults to 100
209197
instances : int
210-
Number of instances of X to plot. Defaults to 10.
198+
Number of instances of X to plot. Defaults to 30.
211199
random_seed : Optional[int], by default None.
212200
Seed used to sample from the posterior. Defaults to None.
213201
sharey : bool
214202
Controls sharing of properties among y-axes. Defaults to True.
215-
rug : bool
216-
Whether to include a rugplot. Defaults to True.
217203
smooth : bool
218204
If True the result will be smoothed by first computing a linear interpolation of the data
219205
over a regular grid and then applying the Savitzky-Golay filter to the interpolated data.
220206
Defaults to True.
221207
grid : str or tuple
222208
How to arrange the subplots. Defaults to "long", one subplot below the other.
223-
Other options are "wide", one subplot next to eachother or a tuple indicating the number of
209+
Other options are "wide", one subplot next to each other or a tuple indicating the number of
224210
rows and columns.
225211
color : matplotlib valid color
226212
Color used to plot the pdp or ice. Defaults to "C0"
@@ -257,17 +243,17 @@ def identity(x):
257243
indices,
258244
var_idx,
259245
var_discrete,
260-
xs_interval,
261-
xs_values,
262-
) = _prepare_plot_data(X, Y, xs_interval, xs_values, var_idx, var_discrete)
246+
_,
247+
_,
248+
) = _prepare_plot_data(X, Y, "linear", None, var_idx, var_discrete)
263249

264250
fig, axes, shape = _get_axes(bartrv, var_idx, grid, sharey, figsize, ax)
265251

266252
instances_ary = rng.choice(range(X.shape[0]), replace=False, size=instances)
267253
idx_s = list(range(X.shape[0]))
268254

269255
count = 0
270-
for var in range(len(var_idx)):
256+
for i_var, var in enumerate(var_idx):
271257
indices_mi = indices[:]
272258
indices_mi.remove(var)
273259
y_pred = []
@@ -283,6 +269,7 @@ def identity(x):
283269

284270
new_x = fake_X[:, var]
285271
p_d = np.array(y_pred)
272+
print(p_d.shape)
286273

287274
for s_i in range(shape):
288275
if centered:
@@ -301,7 +288,7 @@ def identity(x):
301288
idx = np.argsort(new_x)
302289
axes[count].plot(new_x[idx], p_di.mean(0)[idx], color=color_mean)
303290
axes[count].plot(new_x[idx], p_di.T[idx], color=color, alpha=alpha)
304-
axes[count].set_xlabel(x_labels[var])
291+
axes[count].set_xlabel(x_labels[i_var])
305292

306293
count += 1
307294

@@ -349,7 +336,7 @@ def plot_pdp(
349336
For discrete variables these options are omitted.
350337
xs_values : Optional[Union[int, List[float]]], by default None.
351338
Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
352-
points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of
339+
points in the evenly spaced grid. If ``xs_interval="quantiles"`` quantile or sequence of
353340
quantiles to compute, which must be between 0 and 1 inclusive.
354341
Ignored when ``xs_interval="insample"``.
355342
var_idx : Optional[List[int]], by default None.
@@ -717,7 +704,8 @@ def plot_variable_importance(
717704
xlabel_angle: float = 0,
718705
samples: int = 100,
719706
random_seed: Optional[int] = None,
720-
) -> Tuple[List[int], List[plt.Axes]]:
707+
ax: Optional[plt.Axes] = None,
708+
) -> Tuple[List[int], Union[List[plt.Axes], Any]]:
721709
"""
722710
Estimates variable importance from the BART-posterior.
723711
@@ -747,6 +735,8 @@ def plot_variable_importance(
747735
Number of predictions used to compute correlation for subsets of variables. Defaults to 100
748736
random_seed : Optional[int]
749737
random_seed used to sample from the posterior. Defaults to None.
738+
ax : axes
739+
Matplotlib axes.
750740
751741
Returns
752742
-------
@@ -771,7 +761,8 @@ def plot_variable_importance(
771761
if figsize is None:
772762
figsize = (8, 3)
773763

774-
_, ax = plt.subplots(1, 1, figsize=figsize)
764+
if ax is None:
765+
_, ax = plt.subplots(1, 1, figsize=figsize)
775766

776767
if labels is None:
777768
labels_ary = np.arange(n_vars).astype(str)

tests/test_bart.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,6 @@ def test_sample_posterior(self):
153153
{},
154154
{
155155
"samples": 2,
156-
"xs_interval": "quantiles",
157-
"xs_values": [0.25, 0.5, 0.75],
158156
"var_discrete": [3],
159157
},
160158
{"instances": 2},

0 commit comments

Comments
 (0)