Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions env-dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: pymc-bart-dev
channels:
- conda-forge
- defaults
dependencies:
- pymc>=5.16.2,<=5.19.1
- arviz>=0.18.0
- numba
- matplotlib
- numpy
- pytensor
# Development dependencies
- pytest>=4.4.0
- pytest-cov>=2.6.1
- click==8.0.4
- pylint==2.17.4
- pre-commit
- black
- isort
- flake8
- pip
- pip:
- -e .
14 changes: 14 additions & 0 deletions env.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: pymc-bart
channels:
- conda-forge
- defaults
dependencies:
- pymc>=5.16.2,<=5.19.1
- arviz>=0.18.0
- numba
- matplotlib
- numpy
- pytensor
- pip
- pip:
- pymc-bart
90 changes: 64 additions & 26 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,11 +396,18 @@ def identity(x):
p_d = _sample_posterior(
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
)
# need to apply func to full array and to last dimension if it's softmax
if func.__name__ == "softmax":
# categories are always the last dimension
# for some reason, mypy thinks that func can be identity,
# which doesn't have the axis argument
p_d = func(p_d, axis=-1) # type: ignore[call-arg]

with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="hdi currently interprets 2d data")
new_x = fake_X[:, var]
for s_i in range(shape):
p_di = func(p_d[:, :, s_i])
p_di = p_d[:, :, s_i] if func.__name__ == "softmax" else func(p_d[:, :, s_i])
null_pd.append(p_di.mean())
if var in var_discrete:
_, idx_uni = np.unique(new_x, return_index=True)
Expand Down Expand Up @@ -1125,8 +1132,11 @@ def plot_scatter_submodels(
plot_kwargs : dict
Additional keyword arguments for the plot. Defaults to None.
Valid keys are:
- color_ref: matplotlib valid color for the 45 degree line
- marker_scatter: matplotlib valid marker for the scatter plot
- color_scatter: matplotlib valid color for the scatter plot
- alpha_scatter: matplotlib valid alpha for the scatter plot
- color_ref: matplotlib valid color for the 45 degree line
- ls_ref: matplotlib valid linestyle for the reference line
axes : axes
Matplotlib axes.

Expand All @@ -1140,41 +1150,69 @@ def plot_scatter_submodels(
submodels = np.sort(submodels)

indices = vi_results["indices"][submodels]
preds = vi_results["preds"][submodels]
preds_sub = vi_results["preds"][submodels]
preds_all = vi_results["preds_all"]

if labels is None:
labels = vi_results["labels"][submodels]

# handle categorical regression case:
n_cats = None
if preds_all.ndim > 2:
n_cats = preds_all.shape[-1]
indices = np.tile(indices, n_cats)

if ax is None:
_, ax = _get_axes(grid, len(indices), True, True, figsize)

if plot_kwargs is None:
plot_kwargs = {}

if labels is None:
labels = vi_results["labels"][submodels]

if func is not None:
preds = func(preds)
preds_sub = func(preds_sub)
preds_all = func(preds_all)

min_ = min(np.min(preds), np.min(preds_all))
max_ = max(np.max(preds), np.max(preds_all))

for pred, x_label, axi in zip(preds, labels, ax.ravel()):
axi.plot(
pred,
preds_all,
marker=plot_kwargs.get("marker_scatter", "."),
ls="",
color=plot_kwargs.get("color_scatter", "C0"),
alpha=plot_kwargs.get("alpha_scatter", 0.1),
)
axi.set_xlabel(x_label)
axi.axline(
[min_, min_],
[max_, max_],
color=plot_kwargs.get("color_ref", "0.5"),
ls=plot_kwargs.get("ls_ref", "--"),
)
min_ = min(np.min(preds_sub), np.min(preds_all))
max_ = max(np.max(preds_sub), np.max(preds_all))

# handle categorical regression case:
if n_cats is not None:
i = 0
for cat in range(n_cats):
for pred_sub, x_label in zip(preds_sub, labels):
ax[i].plot(
pred_sub[..., cat],
preds_all[..., cat],
marker=plot_kwargs.get("marker_scatter", "."),
ls="",
color=plot_kwargs.get("color_scatter", f"C{cat}"),
alpha=plot_kwargs.get("alpha_scatter", 0.1),
)
ax[i].set(xlabel=x_label, ylabel="ref model", title=f"Category {cat}")
ax[i].axline(
[min_, min_],
[max_, max_],
color=plot_kwargs.get("color_ref", "0.5"),
ls=plot_kwargs.get("ls_ref", "--"),
)
i += 1
else:
for pred_sub, x_label, axi in zip(preds_sub, labels, ax.ravel()):
axi.plot(
pred_sub,
preds_all,
marker=plot_kwargs.get("marker_scatter", "."),
ls="",
color=plot_kwargs.get("color_scatter", "C0"),
alpha=plot_kwargs.get("alpha_scatter", 0.1),
)
axi.set(xlabel=x_label, ylabel="ref model")
axi.axline(
[min_, min_],
[max_, max_],
color=plot_kwargs.get("color_ref", "0.5"),
ls=plot_kwargs.get("ls_ref", "--"),
)
return ax


Expand Down