Skip to content

Commit c2a4f4a

Browse files
authored
Better handling of discrete variables and other minor fixes (#121)
* improve discrete variable pdp and minor fixes
* black
* pin mypy
* fix types
1 parent 1b8ded4 commit c2a4f4a

File tree

3 files changed

+46
-38
lines changed

3 files changed

+46
-38
lines changed

pymc_bart/pgbart.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -96,10 +96,10 @@ class PGBART(ArrayStepShared):
9696
List of value variables for sampler
9797
num_particles : tuple
9898
Number of particles. Defaults to 10
99-
batch : int or tuple
100-
Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
101-
during tuning and after tuning. If a tuple is passed the first element is the batch size
102-
during tuning and the second the batch size after tuning.
99+
batch : tuple
100+
Number of trees fitted per step. The first element is the batch size during tuning and the
101+
second the batch size after tuning. Defaults to (0.1, 0.1), meaning 10% of the `m` trees
102+
during tuning and after tuning.
103103
model: PyMC Model
104104
Optional model for sampling step. Defaults to None (taken from context).
105105
"""

pymc_bart/utils.py

Lines changed: 41 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -157,7 +157,7 @@ def plot_ice(
157157
bartrv: Variable,
158158
X: npt.NDArray[np.float_],
159159
Y: Optional[npt.NDArray[np.float_]] = None,
160-
xs_interval: str = "linear",
160+
xs_interval: str = "quantiles",
161161
xs_values: Optional[Union[int, List[float]]] = None,
162162
var_idx: Optional[List[int]] = None,
163163
var_discrete: Optional[List[int]] = None,
@@ -303,7 +303,7 @@ def identity(x):
303303
idx = np.argsort(new_x)
304304
axes[count].plot(new_x[idx], p_di.mean(0)[idx], color=color_mean)
305305
axes[count].plot(new_x[idx], p_di.T[idx], color=color, alpha=alpha)
306-
axes[count].set_xlabel(x_labels[var])
306+
axes[count].set_xlabel(x_labels[var])
307307

308308
count += 1
309309

@@ -316,7 +316,7 @@ def plot_pdp(
316316
bartrv: Variable,
317317
X: npt.NDArray[np.float_],
318318
Y: Optional[npt.NDArray[np.float_]] = None,
319-
xs_interval: str = "linear",
319+
xs_interval: str = "quantiles",
320320
xs_values: Optional[Union[int, List[float]]] = None,
321321
var_idx: Optional[List[int]] = None,
322322
var_discrete: Optional[List[int]] = None,
@@ -423,35 +423,39 @@ def identity(x):
423423
p_d = _sample_posterior(
424424
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
425425
)
426-
new_x = fake_X[:, var]
427-
for s_i in range(shape):
428-
p_di = func(p_d[:, :, s_i])
429-
if var in var_discrete:
430-
y_means = p_di.mean(0)
431-
hdi = az.hdi(p_di)
432-
axes[count].errorbar(
433-
new_x,
434-
y_means,
435-
(y_means - hdi[:, 0], hdi[:, 1] - y_means),
436-
fmt=".",
437-
color=color,
438-
)
439-
else:
440-
az.plot_hdi(
441-
new_x,
442-
p_di,
443-
smooth=smooth,
444-
fill_kwargs={"alpha": alpha, "color": color},
445-
ax=axes[count],
446-
)
447-
if smooth:
448-
x_data, y_data = _smooth_mean(new_x, p_di, "pdp", smooth_kwargs)
449-
axes[count].plot(x_data, y_data, color=color_mean)
426+
with warnings.catch_warnings():
427+
warnings.filterwarnings("ignore", message="hdi currently interprets 2d data")
428+
new_x = fake_X[:, var]
429+
for s_i in range(shape):
430+
p_di = func(p_d[:, :, s_i])
431+
if var in var_discrete:
432+
_, idx_uni = np.unique(new_x, return_index=True)
433+
y_means = p_di.mean(0)[idx_uni]
434+
hdi = az.hdi(p_di)[idx_uni]
435+
axes[count].errorbar(
436+
new_x[idx_uni],
437+
y_means,
438+
(y_means - hdi[:, 0], hdi[:, 1] - y_means),
439+
fmt=".",
440+
color=color,
441+
)
442+
axes[count].set_xticks(new_x[idx_uni])
450443
else:
451-
axes[count].plot(new_x, p_di.mean(0), color=color_mean)
444+
az.plot_hdi(
445+
new_x,
446+
p_di,
447+
smooth=smooth,
448+
fill_kwargs={"alpha": alpha, "color": color},
449+
ax=axes[count],
450+
)
451+
if smooth:
452+
x_data, y_data = _smooth_mean(new_x, p_di, "pdp", smooth_kwargs)
453+
axes[count].plot(x_data, y_data, color=color_mean)
454+
else:
455+
axes[count].plot(new_x, p_di.mean(0), color=color_mean)
452456
axes[count].set_xlabel(x_labels[var])
453457

454-
count += 1
458+
count += 1
455459

456460
fig.text(-0.05, 0.5, y_label, va="center", rotation="vertical", fontsize=15)
457461

@@ -527,16 +531,20 @@ def _get_axes(
527531
fig.delaxes(axes[i])
528532
axes = axes[:n_plots]
529533
else:
530-
axes = [ax]
531-
fig = ax.get_figure()
534+
if isinstance(ax, np.ndarray):
535+
axes = ax
536+
fig = ax[0].get_figure()
537+
else:
538+
axes = [ax]
539+
fig = ax.get_figure() # type: ignore
532540

533541
return fig, axes, shape
534542

535543

536544
def _prepare_plot_data(
537545
X: npt.NDArray[np.float_],
538546
Y: Optional[npt.NDArray[np.float_]] = None,
539-
xs_interval: str = "linear",
547+
xs_interval: str = "quantiles",
540548
xs_values: Optional[Union[int, List[float]]] = None,
541549
var_idx: Optional[List[int]] = None,
542550
var_discrete: Optional[List[int]] = None,
@@ -710,7 +718,7 @@ def plot_variable_importance(
710718
figsize: Optional[Tuple[float, float]] = None,
711719
samples: int = 100,
712720
random_seed: Optional[int] = None,
713-
) -> Tuple[npt.NDArray[np.int_], List[plt.axes]]:
721+
) -> Tuple[npt.NDArray[np.int_], List[plt.Axes]]:
714722
"""
715723
Estimates variable importance from the BART-posterior.
716724

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
black==22.3.0
22
click==8.0.4
3-
mypy>=1.1.1
3+
mypy==1.3.0
44
pandas-stubs==1.5.3.230304
55
pre-commit
66
pylint==2.17.4

0 commit comments

Comments
 (0)