
Commit a2e6a1d

Merge remote-tracking branch 'refs/remotes/origin/main'
2 parents: 794726f + c38aec8

File tree

2 files changed (+47, -14 lines)

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -12,14 +12,14 @@ ci:
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.6.1
+    rev: v0.6.3
     hooks:
       - id: ruff
         args: ["--fix", "--output-format=full"]
       - id: ruff-format
         args: ["--line-length=100"]
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.11.1
+    rev: v1.11.2
     hooks:
       - id: mypy
         args: [--ignore-missing-imports]

pymc_bart/utils.py

Lines changed: 45 additions & 12 deletions
@@ -8,10 +8,11 @@
 import numpy as np
 import numpy.typing as npt
 import pytensor.tensor as pt
+from numba import jit
 from pytensor.tensor.variable import Variable
 from scipy.interpolate import griddata
 from scipy.signal import savgol_filter
-from scipy.stats import norm, pearsonr
+from scipy.stats import norm
 
 from .tree import Tree
 

@@ -700,8 +701,9 @@ def plot_variable_importance(  # noqa: PLR0915
     method: str = "VI",
     figsize: Optional[Tuple[float, float]] = None,
     xlabel_angle: float = 0,
-    samples: int = 100,
+    samples: int = 50,
     random_seed: Optional[int] = None,
+    plot_kwargs: Optional[Dict[str, Any]] = None,
     ax: Optional[plt.Axes] = None,
 ) -> Tuple[List[int], Union[List[plt.Axes], Any]]:
     """
@@ -733,6 +735,14 @@ def plot_variable_importance(  # noqa: PLR0915
         Number of predictions used to compute correlation for subsets of variables. Defaults to 100
     random_seed : Optional[int]
         random_seed used to sample from the posterior. Defaults to None.
+    plot_kwargs : dict
+        Additional keyword arguments for the plot. Defaults to None.
+        Valid keys are:
+        - color_r2: matplotlib valid color for error bars
+        - marker_r2: matplotlib valid marker for the mean R squared
+        - marker_fc_r2: matplotlib valid marker face color for the mean R squared
+        - ls_ref: matplotlib valid linestyle for the reference line
+        - color_ref: matplotlib valid color for the reference line
     ax : axes
         Matplotlib axes.
 
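For context, a hedged usage sketch of the new plot_kwargs argument follows. The toy data, model, and call are illustrative assumptions, not part of this commit:

import numpy as np
import pymc as pm
import pymc_bart as pmb

# Toy data; names and sizes are illustrative only.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
Y = X[:, 0] + rng.normal(scale=0.1, size=100)

with pm.Model():
    mu = pmb.BART("mu", X, Y, m=20)
    pm.Normal("y", mu=mu, sigma=0.1, observed=Y)
    idata = pm.sample(tune=200, draws=200, random_seed=0)

# Style the error bars and the new reference line/band via plot_kwargs.
pmb.plot_variable_importance(
    idata,
    mu,
    X,
    samples=50,
    plot_kwargs={
        "color_r2": "C0",        # error-bar color
        "marker_r2": "s",        # marker for the mean R²
        "marker_fc_r2": "none",  # marker face color
        "ls_ref": ":",           # reference-line linestyle
        "color_ref": "C3",       # reference line/band color
    },
)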
@@ -745,6 +755,9 @@ def plot_variable_importance(  # noqa: PLR0915
 
     all_trees = bartrv.owner.op.all_trees
 
+    if plot_kwargs is None:
+        plot_kwargs = {}
+
     if bartrv.ndim == 1:  # type: ignore
         shape = 1
     else:
@@ -773,6 +786,10 @@ def plot_variable_importance(  # noqa: PLR0915
         all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
     )
 
+    r_2_ref = np.array(
+        [pearsonr2(predicted_all[j], predicted_all[j + 1]) for j in range(samples - 1)]
+    )
+
     if method == "VI":
         idxs = np.argsort(
             idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
@@ -794,10 +811,7 @@ def plot_variable_importance(  # noqa: PLR0915
                 shape=shape,
             )
             r_2 = np.array(
-                [
-                    pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0] ** 2
-                    for j in range(samples)
-                ]
+                [pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)]
             )
             r2_mean[idx] = np.mean(r_2)
             r2_hdi[idx] = az.hdi(r_2)
@@ -833,10 +847,7 @@ def plot_variable_importance(  # noqa: PLR0915
                 # Calculate Pearson correlation for each sample and find the mean
                 r_2 = np.zeros(samples)
                 for j in range(samples):
-                    r_2[j] = (
-                        (pearsonr(predicted_all[j].flatten(), predicted_subset[j].flatten())[0])
-                        ** 2
-                    )
+                    r_2[j] = pearsonr2(predicted_all[j], predicted_subset[j])
                 mean_r_2 = np.mean(r_2, dtype=float)
                 # Identify the least important combination of variables
                 # based on the maximum mean squared Pearson correlation
@@ -872,9 +883,21 @@ def plot_variable_importance(  # noqa: PLR0915
         ticks,
         r2_mean,
         np.array((r2_yerr_min, r2_yerr_max)),
-        color="C0",
+        color=plot_kwargs.get("color_r2", "k"),
+        fmt=plot_kwargs.get("marker_r2", "o"),
+        mfc=plot_kwargs.get("marker_fc_r2", "white"),
+    )
+    ax.axhline(
+        np.mean(r_2_ref),
+        ls=plot_kwargs.get("ls_ref", "--"),
+        color=plot_kwargs.get("color_ref", "grey"),
+    )
+    ax.fill_between(
+        [-0.5, n_vars - 0.5],
+        *az.hdi(r_2_ref),
+        alpha=0.1,
+        color=plot_kwargs.get("color_ref", "grey"),
     )
-    ax.axhline(r2_mean[-1], ls="--", color="0.5")
     ax.set_xticks(ticks, new_labels, rotation=xlabel_angle)
     ax.set_ylabel("R²", rotation=0, labelpad=12)
     ax.set_ylim(0, 1)
@@ -890,3 +913,13 @@ def generate_sequences(n_vars, i_var, include):
     else:
         sequences = [()]
     return sequences
+
+
+@jit(nopython=True)
+def pearsonr2(A, B):
+    """Compute the squared Pearson correlation coefficient"""
+    A = A.flatten()
+    B = B.flatten()
+    am = A - np.mean(A)
+    bm = B - np.mean(B)
+    return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2))
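As a quick sanity check (assuming SciPy is available), the new numba-jitted helper should agree with the square of scipy.stats.pearsonr on flattened inputs:

import numpy as np
from scipy.stats import pearsonr

rng = np.random.default_rng(42)
a = rng.normal(size=(5, 3))
b = rng.normal(size=(5, 3))
# pearsonr2 flattens internally, so compare against pearsonr on flattened arrays.
assert np.isclose(pearsonr2(a, b), pearsonr(a.flatten(), b.flatten())[0] ** 2)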
