Skip to content

Commit a952dc7

Browse files
authored
Merge pull request #56 from pymc-labs/bayesian-r2
#43 use Bayesian R2 measures for pymc models + rerun notebook
2 parents 2a201ed + 68a717f commit a952dc7

File tree

6 files changed

+5141
-4853
lines changed

6 files changed

+5141
-4853
lines changed

causalpy/pymc_experiments.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ def plot(self):
9595
self.datapost.index, self.post_pred["posterior_predictive"].y_hat, ax=ax[0]
9696
)
9797
ax[0].plot(self.datapost.index, self.post_y, "k.")
98-
ax[0].set(title=f"$R^2$ on pre-intervention data = {self.score:.3f}")
98+
ax[0].set(
99+
title=f"Pre-intervention Bayesian $R^2$: {self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
100+
)
99101

100102
plot_xY(self.datapre.index, self.pre_impact, ax=ax[1])
101103
plot_xY(self.datapost.index, self.post_impact, ax=ax[1])
@@ -393,7 +395,7 @@ def plot(self):
393395
ax=ax,
394396
)
395397
# create strings to compose title
396-
r2 = f"$R^2$ on all data = {self.score:.3f}"
398+
r2 = f"Bayesian $R^2$ on all data = {self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
397399
percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
398400
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
399401
discon = f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}, "

causalpy/pymc_models.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import arviz as az
22
import numpy as np
33
import pymc as pm
4-
from sklearn.metrics import r2_score
4+
from arviz import r2_score
55

66

77
class ModelBuilder(pm.Model):
@@ -39,12 +39,21 @@ def predict(self, X):
3939
return post_pred
4040

4141
def score(self, X, y):
42-
"""Score the predictions :math:`R^2` given inputs ``X`` and outputs ``y``."""
42+
"""Score the Bayesian :math:`R^2` given inputs ``X`` and outputs ``y``.
43+
44+
.. caution::
45+
46+
The Bayesian :math:`R^2` is not the same as the traditional coefficient of determination, https://en.wikipedia.org/wiki/Coefficient_of_determination.
47+
48+
"""
4349
yhat = self.predict(X)
44-
yhat = az.extract(yhat, group="posterior_predictive", var_names="y_hat").mean(
45-
dim="sample"
46-
)
47-
return r2_score(y, yhat)
50+
yhat = az.extract(
51+
yhat, group="posterior_predictive", var_names="y_hat"
52+
).T.values
53+
# Note: First argument must be a 1D array
54+
return r2_score(y.flatten(), yhat)
55+
56+
# .stack(sample=("chain", "draw")
4857

4958

5059
class WeightedSumFitter(ModelBuilder):

docs/notebooks/pymc_demos.ipynb

Lines changed: 31 additions & 97 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)