Skip to content

Commit b20d074

Browse files
authored
add submodels arguments to plot subsets (#200)
1 parent 9ec4de8 commit b20d074

File tree

1 file changed

+47
-29
lines changed

1 file changed

+47
-29
lines changed

pymc_bart/utils.py

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
705705
706706
Parameters
707707
----------
708-
idata: InferenceData
708+
idata : InferenceData
709709
InferenceData containing a collection of BART_trees in sample_stats group
710710
X : npt.NDArray[np.float64]
711711
The covariate matrix.
@@ -784,7 +784,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
784784
785785
Parameters
786786
----------
787-
idata: InferenceData
787+
idata : InferenceData
788788
InferenceData containing a collection of BART_trees in sample_stats group
789789
bartrv : BART Random Variable
790790
BART variable once the model that includes it has been fitted.
@@ -949,8 +949,10 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
949949

950950
indices = least_important_vars[::-1]
951951

952+
labels = np.array(["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)])
953+
952954
vi_results = {
953-
"indices": indices,
955+
"indices": np.asarray(indices),
954956
"labels": labels[indices],
955957
"r2_mean": r2_mean,
956958
"r2_hdi": r2_hdi,
@@ -962,8 +964,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
962964

963965
def plot_variable_importance(
964966
vi_results: dict,
965-
labels=None,
966-
figsize=None,
967+
submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None,
968+
labels: Optional[list[str]] = None,
969+
figsize: Optional[tuple[float, float]] = None,
967970
plot_kwargs: Optional[dict[str, Any]] = None,
968971
ax: Optional[plt.Axes] = None,
969972
):
@@ -974,8 +977,11 @@ def plot_variable_importance(
974977
----------
975978
vi_results: Dictionary
976979
Dictionary computed with `compute_variable_importance`
977-
X : npt.NDArray[np.float64]
978-
The covariate matrix.
980+
submodels : Optional[Union[list[int], np.ndarray, tuple[int, ...]]]
981+
List of the indices of the submodels to plot. Defaults to None, in which case all variables are plotted.
982+
The indices correspond to the order computed by `compute_variable_importance`.
983+
For example `submodels=[0,1]` will plot the two most important variables.
984+
`submodels=[1,0]` is equivalent as values are sorted before use.
979985
labels : Optional[list[str]]
980986
List of the names of the covariates. If X is a DataFrame the names of the covariables will
981987
be taken from it and this argument will be ignored.
@@ -995,11 +1001,15 @@ def plot_variable_importance(
9951001
-------
9961002
axes: matplotlib axes
9971003
"""
1004+
if submodels is None:
1005+
submodels = np.sort(vi_results["indices"])
1006+
else:
1007+
submodels = np.sort(submodels)
9981008

999-
indices = vi_results["indices"]
1000-
r2_mean = vi_results["r2_mean"]
1001-
r2_hdi = vi_results["r2_hdi"]
1002-
preds = vi_results["preds"]
1009+
indices = vi_results["indices"][submodels]
1010+
r2_mean = vi_results["r2_mean"][submodels]
1011+
r2_hdi = vi_results["r2_hdi"][submodels]
1012+
preds = vi_results["preds"][submodels]
10031013
preds_all = vi_results["preds_all"]
10041014
samples = preds.shape[1]
10051015

@@ -1016,9 +1026,7 @@ def plot_variable_importance(
10161026
_, ax = plt.subplots(1, 1, figsize=figsize)
10171027

10181028
if labels is None:
1019-
labels = vi_results["labels"]
1020-
1021-
labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
1029+
labels = vi_results["labels"][submodels]
10221030

10231031
r_2_ref = np.array([pearsonr2(preds_all[j], preds_all[j + 1]) for j in range(samples - 1)])
10241032

@@ -1059,21 +1067,27 @@ def plot_variable_importance(
10591067
def plot_scatter_submodels(
10601068
vi_results: dict,
10611069
func: Optional[Callable] = None,
1070+
submodels: Optional[Union[list[int], np.ndarray]] = None,
10621071
grid: str = "long",
1063-
labels=None,
1072+
labels: Optional[list[str]] = None,
10641073
figsize: Optional[tuple[float, float]] = None,
10651074
plot_kwargs: Optional[dict[str, Any]] = None,
1066-
axes: Optional[plt.Axes] = None,
1067-
):
1075+
ax: Optional[plt.Axes] = None,
1076+
) -> list[plt.Axes]:
10681077
"""
10691078
Plot submodel's predictions against reference-model's predictions.
10701079
10711080
Parameters
10721081
----------
1073-
vi_results: Dictionary
1082+
vi_results : Dictionary
10741083
Dictionary computed with `compute_variable_importance`
10751084
func : Optional[Callable], by default None.
10761085
Arbitrary function to apply to the predictions. Defaults to the identity function.
1086+
submodels : Optional[Union[list[int], np.ndarray]]
1087+
List of the indices of the submodels to plot. Defaults to None, in which case all variables are plotted.
1088+
The indices correspond to the order computed by `compute_variable_importance`.
1089+
For example `submodels=[0,1]` will plot the two most important variables.
1090+
`submodels=[1,0]` is equivalent as values are sorted before use.
10771091
grid : str or tuple
10781092
How to arrange the subplots. Defaults to "long", one subplot below the other.
10791093
Other options are "wide", one subplot next to each other or a tuple indicating the number
@@ -1092,20 +1106,23 @@ def plot_scatter_submodels(
10921106
-------
10931107
axes: matplotlib axes
10941108
"""
1095-
indices = vi_results["indices"]
1096-
preds = vi_results["preds"]
1109+
if submodels is None:
1110+
submodels = np.sort(vi_results["indices"])
1111+
else:
1112+
submodels = np.sort(submodels)
1113+
1114+
indices = vi_results["indices"][submodels]
1115+
preds = vi_results["preds"][submodels]
10971116
preds_all = vi_results["preds_all"]
10981117

1099-
if axes is None:
1100-
_, axes = _get_axes(grid, len(indices), True, True, figsize)
1118+
if ax is None:
1119+
_, ax = _get_axes(grid, len(indices), True, True, figsize)
11011120

11021121
if plot_kwargs is None:
11031122
plot_kwargs = {}
11041123

11051124
if labels is None:
1106-
labels = vi_results["labels"]
1107-
1108-
labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]
1125+
labels = vi_results["labels"][submodels]
11091126

11101127
if func is not None:
11111128
preds = func(preds)
@@ -1114,22 +1131,23 @@ def plot_scatter_submodels(
11141131
min_ = min(np.min(preds), np.min(preds_all))
11151132
max_ = max(np.max(preds), np.max(preds_all))
11161133

1117-
for pred, x_label, ax in zip(preds, labels, axes.ravel()):
1118-
ax.plot(
1134+
for pred, x_label, axi in zip(preds, labels, ax.ravel()):
1135+
axi.plot(
11191136
pred,
11201137
preds_all,
11211138
marker=plot_kwargs.get("marker_scatter", "."),
11221139
ls="",
11231140
color=plot_kwargs.get("color_scatter", "C0"),
11241141
alpha=plot_kwargs.get("alpha_scatter", 0.1),
11251142
)
1126-
ax.set_xlabel(x_label)
1127-
ax.axline(
1143+
axi.set_xlabel(x_label)
1144+
axi.axline(
11281145
[min_, min_],
11291146
[max_, max_],
11301147
color=plot_kwargs.get("color_ref", "0.5"),
11311148
ls=plot_kwargs.get("ls_ref", "--"),
11321149
)
1150+
return ax
11331151

11341152

11351153
def generate_sequences(n_vars, i_var, include):

0 commit comments

Comments
 (0)