Skip to content

Commit d72c963

Browse files
committed
improve docs, aesthetics and functionality
1 parent 4ef2dd0 commit d72c963

File tree

1 file changed

+67
-6
lines changed

1 file changed

+67
-6
lines changed

pymc_bart/utils.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,25 +1058,86 @@ def plot_variable_importance(
10581058
return ax
10591059

10601060

1061-
def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None):
1061+
def plot_scatter_submodels(
1062+
vi_results: dict,
1063+
X: npt.NDArray[np.float64],
1064+
func: Optional[Callable] = None,
1065+
grid: str = "long",
1066+
figsize: Optional[Tuple[float, float]] = None,
1067+
plot_kwargs: Optional[Dict[str, Any]] = None,
1068+
axes: Optional[plt.Axes] = None,
1069+
):
1070+
"""
1071+
Plot submodel's predictions against reference-model's predictions.
1072+
1073+
Parameters
1074+
----------
1075+
vi_results: Dictionary
1076+
Dictionary computed with `compute_variable_importance`
1077+
X : npt.NDArray[np.float64]
1078+
The covariate matrix.
1079+
func : Optional[Callable], by default None.
1080+
Arbitrary function to apply to the predictions. Defaults to the identity function.
1081+
grid : str or tuple
1082+
How to arrange the subplots. Defaults to "long", one subplot below the other.
1083+
Other options are "wide", one subplot next to each other or a tuple indicating the number
1084+
of rows and columns.
1085+
plot_kwargs : dict
1086+
Additional keyword arguments for the plot. Defaults to None.
1087+
Valid keys are:
1088+
- color_ref: matplotlib valid color for the 45 degree line
1089+
- color_scatter: matplotlib valid color for the scatter plot
1090+
axes : axes
1091+
Matplotlib axes.
1092+
1093+
Returns
1094+
-------
1095+
axes: matplotlib axes
1096+
"""
10621097
indices = vi_results["indices"]
10631098
preds = vi_results["preds"]
10641099
preds_all = vi_results["preds_all"]
1100+
n_vars = len(indices)
10651101

10661102
if axes is None:
1067-
_, axes = _get_axes(grid, len(indices), False, True, None)
1103+
_, axes = _get_axes(grid, len(indices), True, True, figsize)
1104+
1105+
if plot_kwargs is None:
1106+
plot_kwargs = {}
1107+
1108+
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
1109+
labels = X.columns
1110+
1111+
if labels is None:
1112+
labels = np.arange(n_vars).astype(str)
1113+
else:
1114+
labels = np.asarray(labels)
1115+
1116+
new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
10681117

1069-
func = None
10701118
if func is not None:
10711119
preds = func(preds)
10721120
preds_all = func(preds_all)
10731121

10741122
min_ = min(np.min(preds), np.min(preds_all))
10751123
max_ = max(np.max(preds), np.max(preds_all))
10761124

1077-
for pred, ax in zip(preds, axes.ravel()):
1078-
ax.plot(pred, preds_all, ".", color="C0", alpha=0.1)
1079-
ax.axline([min_, min_], [max_, max_], color="0.5")
1125+
for pred, x_label, ax in zip(preds, new_labels, axes.ravel()):
1126+
ax.plot(
1127+
pred,
1128+
preds_all,
1129+
marker=plot_kwargs.get("marker_scatter", "."),
1130+
ls="",
1131+
color=plot_kwargs.get("color_scatter", "C0"),
1132+
alpha=plot_kwargs.get("alpha_scatter", 0.1),
1133+
)
1134+
ax.set_xlabel(x_label)
1135+
ax.axline(
1136+
[min_, min_],
1137+
[max_, max_],
1138+
color=plot_kwargs.get("color_ref", "0.5"),
1139+
ls=plot_kwargs.get("ls_ref", "--"),
1140+
)
10801141

10811142

10821143
def generate_sequences(n_vars, i_var, include):

0 commit comments

Comments
 (0)