Skip to content

Commit 40f1220

Browse files
authored
improve docs, aesthetics and functionality (#198)
* improve docs, aesthetics and functionality * remove X argument from plots
1 parent 4ef2dd0 commit 40f1220

File tree

2 files changed

+72
-19
lines changed

2 files changed

+72
-19
lines changed

pymc_bart/utils.py

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -824,10 +824,14 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
824824
else:
825825
shape = bartrv.eval().shape[0]
826826

827+
n_vars = X.shape[1]
828+
827829
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
830+
labels = X.columns
828831
X = X.to_numpy()
832+
else:
833+
labels = np.arange(n_vars).astype(str)
829834

830-
n_vars = X.shape[1]
831835
r2_mean = np.zeros(n_vars)
832836
r2_hdi = np.zeros((n_vars, 2))
833837
preds = np.zeros((n_vars, samples, bartrv.eval().shape[0]))
@@ -947,6 +951,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
947951

948952
vi_results = {
949953
"indices": indices,
954+
"labels": labels[indices],
950955
"r2_mean": r2_mean,
951956
"r2_hdi": r2_hdi,
952957
"preds": preds,
@@ -957,7 +962,6 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
957962

958963
def plot_variable_importance(
959964
vi_results: dict,
960-
X: npt.NDArray[np.float64],
961965
labels=None,
962966
figsize=None,
963967
plot_kwargs: Optional[Dict[str, Any]] = None,
@@ -1008,19 +1012,13 @@ def plot_variable_importance(
10081012
if figsize is None:
10091013
figsize = (8, 3)
10101014

1011-
if hasattr(X, "columns") and hasattr(X, "to_numpy"):
1012-
labels = X.columns
1013-
X = X.to_numpy()
1014-
10151015
if ax is None:
10161016
_, ax = plt.subplots(1, 1, figsize=figsize)
10171017

10181018
if labels is None:
1019-
labels = np.arange(n_vars).astype(str)
1020-
else:
1021-
labels = np.asarray(labels)
1019+
labels = vi_results["labels"]
10221020

1023-
new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
1021+
labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
10241022

10251023
r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)])
10261024

@@ -1048,7 +1046,7 @@ def plot_variable_importance(
10481046
)
10491047
ax.set_xticks(
10501048
ticks,
1051-
new_labels,
1049+
labels,
10521050
rotation=plot_kwargs.get("rotation", 0),
10531051
)
10541052
ax.set_ylabel("R²", rotation=0, labelpad=12)
@@ -1058,25 +1056,80 @@ def plot_variable_importance(
10581056
return ax
10591057

10601058

1061-
def plot_scatter_submodels(vi_results, func=None, grid="long", axes=None):
1059+
def plot_scatter_submodels(
1060+
vi_results: dict,
1061+
func: Optional[Callable] = None,
1062+
grid: str = "long",
1063+
labels=None,
1064+
figsize: Optional[Tuple[float, float]] = None,
1065+
plot_kwargs: Optional[Dict[str, Any]] = None,
1066+
axes: Optional[plt.Axes] = None,
1067+
):
1068+
"""
1069+
Plot submodel's predictions against reference-model's predictions.
1070+
1071+
Parameters
1072+
----------
1073+
vi_results: Dictionary
1074+
Dictionary computed with `compute_variable_importance`
1075+
func : Optional[Callable], by default None.
1076+
Arbitrary function to apply to the predictions. Defaults to the identity function.
1077+
grid : str or tuple
1078+
How to arrange the subplots. Defaults to "long", one subplot below the other.
1079+
Other options are "wide", one subplot next to each other or a tuple indicating the number
1080+
of rows and columns.
1081+
labels : Optional[List[str]]
1082+
List of the names of the covariates.
1083+
plot_kwargs : dict
1084+
Additional keyword arguments for the plot. Defaults to None.
1085+
Valid keys are:
1086+
- color_ref: matplotlib valid color for the 45 degree line
1087+
- color_scatter: matplotlib valid color for the scatter plot
1088+
axes : axes
1089+
Matplotlib axes.
1090+
1091+
Returns
1092+
-------
1093+
axes: matplotlib axes
1094+
"""
10621095
indices = vi_results["indices"]
10631096
preds = vi_results["preds"]
10641097
preds_all = vi_results["preds_all"]
10651098

10661099
if axes is None:
1067-
_, axes = _get_axes(grid, len(indices), False, True, None)
1100+
_, axes = _get_axes(grid, len(indices), True, True, figsize)
1101+
1102+
if plot_kwargs is None:
1103+
plot_kwargs = {}
1104+
1105+
if labels is None:
1106+
labels = vi_results["labels"]
1107+
1108+
labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
10681109

1069-
func = None
10701110
if func is not None:
10711111
preds = func(preds)
10721112
preds_all = func(preds_all)
10731113

10741114
min_ = min(np.min(preds), np.min(preds_all))
10751115
max_ = max(np.max(preds), np.max(preds_all))
10761116

1077-
for pred, ax in zip(preds, axes.ravel()):
1078-
ax.plot(pred, preds_all, ".", color="C0", alpha=0.1)
1079-
ax.axline([min_, min_], [max_, max_], color="0.5")
1117+
for pred, x_label, ax in zip(preds, labels, axes.ravel()):
1118+
ax.plot(
1119+
pred,
1120+
preds_all,
1121+
marker=plot_kwargs.get("marker_scatter", "."),
1122+
ls="",
1123+
color=plot_kwargs.get("color_scatter", "C0"),
1124+
alpha=plot_kwargs.get("alpha_scatter", 0.1),
1125+
)
1126+
ax.set_xlabel(x_label)
1127+
ax.axline(
1128+
[min_, min_],
1129+
[max_, max_],
1130+
color=plot_kwargs.get("color_ref", "0.5"),
1131+
ls=plot_kwargs.get("ls_ref", "--"),
1132+
)
10801133

10811134

10821135
def generate_sequences(n_vars, i_var, include):

tests/test_bart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,8 @@ def test_vi(self, kwargs):
193193
vi_results = pmb.compute_variable_importance(
194194
self.idata, bartrv=self.mu, X=self.X, samples=samples
195195
)
196-
pmb.plot_variable_importance(vi_results, X=self.X, **kwargs)
197-
pmb.plot_scatter_submodels(vi_results)
196+
pmb.plot_variable_importance(vi_results, **kwargs)
197+
pmb.plot_scatter_submodels(vi_results, **kwargs)
198198

199199
def test_pdp_pandas_labels(self):
200200
pd = pytest.importorskip("pandas")

0 commit comments

Comments
 (0)