Skip to content

Commit d41e239

Browse files
committed
Expand scatter_submodels to categorical likelihood
1 parent 2c00358 commit d41e239

File tree

1 file changed

+58
-25
lines changed

1 file changed

+58
-25
lines changed

pymc_bart/utils.py

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,8 +1125,11 @@ def plot_scatter_submodels(
11251125
plot_kwargs : dict
11261126
Additional keyword arguments for the plot. Defaults to None.
11271127
Valid keys are:
1128-
- color_ref: matplotlib valid color for the 45 degree line
1128+
- marker_scatter: matplotlib valid marker for the scatter plot
11291129
- color_scatter: matplotlib valid color for the scatter plot
1130+
- alpha_scatter: matplotlib valid alpha for the scatter plot
1131+
- color_ref: matplotlib valid color for the 45 degree line
1132+
- ls_ref: matplotlib valid linestyle for the reference line
11301133
axes : axes
11311134
Matplotlib axes.
11321135
@@ -1140,41 +1143,71 @@ def plot_scatter_submodels(
11401143
submodels = np.sort(submodels)
11411144

11421145
indices = vi_results["indices"][submodels]
1143-
preds = vi_results["preds"][submodels]
1146+
preds_sub = vi_results["preds"][submodels]
11441147
preds_all = vi_results["preds_all"]
11451148

1149+
if labels is None:
1150+
labels = vi_results["labels"][submodels]
1151+
1152+
# handle categorical regression case:
1153+
n_cats = None
1154+
if preds_all.ndim > 2:
1155+
n_cats = preds_all.shape[-1]
1156+
indices = np.tile(indices, n_cats)
1157+
# labels = np.tile(labels, n_cats)
1158+
# cats = np.repeat(np.arange(n_cats), len(indices) // n_cats)
1159+
11461160
if ax is None:
11471161
_, ax = _get_axes(grid, len(indices), True, True, figsize)
11481162

11491163
if plot_kwargs is None:
11501164
plot_kwargs = {}
11511165

1152-
if labels is None:
1153-
labels = vi_results["labels"][submodels]
1154-
11551166
if func is not None:
1156-
preds = func(preds)
1167+
preds_sub = func(preds_sub)
11571168
preds_all = func(preds_all)
11581169

1159-
min_ = min(np.min(preds), np.min(preds_all))
1160-
max_ = max(np.max(preds), np.max(preds_all))
1161-
1162-
for pred, x_label, axi in zip(preds, labels, ax.ravel()):
1163-
axi.plot(
1164-
pred,
1165-
preds_all,
1166-
marker=plot_kwargs.get("marker_scatter", "."),
1167-
ls="",
1168-
color=plot_kwargs.get("color_scatter", "C0"),
1169-
alpha=plot_kwargs.get("alpha_scatter", 0.1),
1170-
)
1171-
axi.set_xlabel(x_label)
1172-
axi.axline(
1173-
[min_, min_],
1174-
[max_, max_],
1175-
color=plot_kwargs.get("color_ref", "0.5"),
1176-
ls=plot_kwargs.get("ls_ref", "--"),
1177-
)
1170+
min_ = min(np.min(preds_sub), np.min(preds_all))
1171+
max_ = max(np.max(preds_sub), np.max(preds_all))
1172+
1173+
# handle categorical regression case:
1174+
if n_cats is not None:
1175+
i = 0
1176+
for cat in range(n_cats):
1177+
for pred_sub, x_label in zip(preds_sub, labels):
1178+
ax[i].plot(
1179+
pred_sub[..., cat],
1180+
preds_all[..., cat],
1181+
marker=plot_kwargs.get("marker_scatter", "."),
1182+
ls="",
1183+
color=plot_kwargs.get("color_scatter", f"C{cat}"),
1184+
alpha=plot_kwargs.get("alpha_scatter", 0.1),
1185+
)
1186+
ax[i].set(xlabel=x_label, ylabel="ref model", title=f"Category {cat}")
1187+
ax[i].axline(
1188+
[min_, min_],
1189+
[max_, max_],
1190+
color=plot_kwargs.get("color_ref", "0.5"),
1191+
ls=plot_kwargs.get("ls_ref", "--"),
1192+
)
1193+
i += 1
1194+
else:
1195+
for pred_sub, x_label, axi in zip(preds_sub, labels, ax.ravel()):
1196+
axi.plot(
1197+
pred_sub,
1198+
preds_all,
1199+
marker=plot_kwargs.get("marker_scatter", "."),
1200+
ls="",
1201+
color=plot_kwargs.get("color_scatter", "C0"),
1202+
alpha=plot_kwargs.get("alpha_scatter", 0.1),
1203+
)
1204+
axi.set(xlabel=x_label, ylabel="ref model")
1205+
axi.axline(
1206+
[min_, min_],
1207+
[max_, max_],
1208+
color=plot_kwargs.get("color_ref", "0.5"),
1209+
ls=plot_kwargs.get("ls_ref", "--"),
1210+
)
11781211
return ax
11791212

11801213

0 commit comments

Comments
 (0)